first commit
This commit is contained in:
1
raymarching/__init__.py
Normal file
1
raymarching/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from .raymarching import *
|
||||
41
raymarching/backend.py
Normal file
41
raymarching/backend.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os
from torch.utils.cpp_extension import load

# Directory containing this file; the CUDA sources live in ./src.
_src_path = os.path.dirname(os.path.abspath(__file__))

# Flags passed to nvcc; the -U flags re-enable half-precision operators
# that torch's build otherwise disables.
nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]
        # implicitly returns None when no MSVC installation is found

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

# JIT-compile the extension on import; `load` caches the build, so only the
# first import pays the compilation cost.
_backend = load(name='_raymarching',
                extra_cflags=c_flags,
                extra_cuda_cflags=nvcc_flags,
                sources=[os.path.join(_src_path, 'src', f) for f in [
                    'raymarching.cu',
                    'bindings.cpp',
                ]],
                )

__all__ = ['_backend']
|
||||
398
raymarching/raymarching.py
Normal file
398
raymarching/raymarching.py
Normal file
@@ -0,0 +1,398 @@
|
||||
import numpy as np
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
|
||||
# lazy building:
|
||||
# `import raymarching` will not immediately build the extension, only if you actually call any functions.
|
||||
|
||||
# Cached handle to the compiled CUDA extension (filled on first use).
BACKEND = None


def get_backend():
    """Return the compiled CUDA backend, building/importing it lazily.

    Prefers a pre-installed `_raymarching` extension; falls back to the
    JIT-compiling `.backend` module. The result is cached in the module-level
    `BACKEND` so the import cost is paid at most once.
    """
    global BACKEND

    if BACKEND is not None:
        return BACKEND

    try:
        import _raymarching as _backend
    except ImportError:
        from .backend import _backend

    BACKEND = _backend
    return BACKEND
|
||||
|
||||
# ----------------------------------------
|
||||
# utils
|
||||
# ----------------------------------------
|
||||
|
||||
class _near_far_from_aabb(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, aabb, min_near=0.2):
        ''' Compute each ray's entry/exit times against an AABB (CUDA).

        Args:
            rays_o: float, [N, 3], ray origins
            rays_d: float, [N, 3], ray directions
            aabb: float, [6], (xmin, ymin, zmin, xmax, ymax, zmax)
            min_near: float, scalar, lower clamp applied to the near time
        Returns:
            nears: float, [N]
            fars: float, [N]
        '''
        rays_o = rays_o if rays_o.is_cuda else rays_o.cuda()
        rays_d = rays_d if rays_d.is_cuda else rays_d.cuda()

        # flatten to [N, 3] and make memory layouts kernel-friendly
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        num_rays = rays_o.shape[0]

        # outputs are fully written by the kernel, so empty allocation is fine
        nears = torch.empty(num_rays, dtype=rays_o.dtype, device=rays_o.device)
        fars = torch.empty(num_rays, dtype=rays_o.dtype, device=rays_o.device)

        get_backend().near_far_from_aabb(rays_o, rays_d, aabb, num_rays, min_near, nears, fars)

        return nears, fars

near_far_from_aabb = _near_far_from_aabb.apply
|
||||
|
||||
|
||||
class _sph_from_ray(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, radius):
        ''' Project rays onto the background sphere (CUDA).

        Assumes every ray origin lies inside Sphere(radius).

        Args:
            rays_o: [N, 3]
            rays_d: [N, 3]
            radius: scalar, float
        Return:
            coords: [N, 2], in [-1, 1], theta and phi on the sphere (further-surface).
        '''
        rays_o = rays_o if rays_o.is_cuda else rays_o.cuda()
        rays_d = rays_d if rays_d.is_cuda else rays_d.cuda()

        # flatten and make contiguous for the kernel
        rays_o = rays_o.contiguous().view(-1, 3)
        rays_d = rays_d.contiguous().view(-1, 3)

        num_rays = rays_o.shape[0]
        coords = torch.empty(num_rays, 2, dtype=rays_o.dtype, device=rays_o.device)

        get_backend().sph_from_ray(rays_o, rays_d, radius, num_rays, coords)

        return coords

sph_from_ray = _sph_from_ray.apply
|
||||
|
||||
|
||||
class _morton3D(Function):
    @staticmethod
    def forward(ctx, coords):
        ''' Encode 3D grid coordinates as Morton (Z-order) indices (CUDA).

        Args:
            coords: [N, 3], int32, each component in [0, 128)
                    (torch has no uint32 dtype, hence int32)
        Returns:
            indices: [N], int32, in [0, 128^3)
        '''
        if not coords.is_cuda:
            coords = coords.cuda()

        num = coords.shape[0]
        indices = torch.empty(num, dtype=torch.int32, device=coords.device)

        # cast defensively to int32 before handing off to the kernel
        get_backend().morton3D(coords.int(), num, indices)

        return indices

morton3D = _morton3D.apply
|
||||
|
||||
class _morton3D_invert(Function):
    @staticmethod
    def forward(ctx, indices):
        ''' Decode Morton (Z-order) indices back into 3D grid coordinates (CUDA).

        Args:
            indices: [N], int32, in [0, 128^3)
        Returns:
            coords: [N, 3], int32, each component in [0, 128)
        '''
        if not indices.is_cuda:
            indices = indices.cuda()

        num = indices.shape[0]
        coords = torch.empty(num, 3, dtype=torch.int32, device=indices.device)

        # cast defensively to int32 before handing off to the kernel
        get_backend().morton3D_invert(indices.int(), num, coords)

        return coords

morton3D_invert = _morton3D_invert.apply
|
||||
|
||||
|
||||
class _packbits(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, grid, thresh, bitfield=None):
        ''' Pack the density grid into a bitfield to accelerate ray marching (CUDA).

        Args:
            grid: float, [C, H * H * H], assume H % 2 == 0
            thresh: float, density threshold; cells above it set their bit
            bitfield: optional preallocated uint8 output buffer
        Returns:
            bitfield: uint8, [C * H * H * H / 8], one bit per grid cell
        '''
        grid = grid.cuda() if not grid.is_cuda else grid
        grid = grid.contiguous()

        # one output byte packs 8 grid cells
        num_bytes = grid.shape[0] * grid.shape[1] // 8

        if bitfield is None:
            bitfield = torch.empty(num_bytes, dtype=torch.uint8, device=grid.device)

        get_backend().packbits(grid, num_bytes, thresh, bitfield)

        return bitfield

packbits = _packbits.apply
|
||||
|
||||
|
||||
class _flatten_rays(Function):
    @staticmethod
    def forward(ctx, rays, M):
        ''' Expand per-ray (offset, count) pairs into a flat per-point ray index (CUDA).

        Args:
            rays: [N, 2], all rays' (point_offset, point_count)
            M: scalar, int, total point count (not recoverable from `rays` alone,
               so the caller must supply it)
        Returns:
            res: [M], int32, res[j] = index of the ray that produced point j
        '''
        rays = rays if rays.is_cuda else rays.cuda()
        rays = rays.contiguous()

        num_rays = rays.shape[0]

        # zero-init so any slot the kernel leaves untouched maps to ray 0
        res = torch.zeros(M, dtype=torch.int, device=rays.device)

        get_backend().flatten_rays(rays, num_rays, M, res)

        return res

flatten_rays = _flatten_rays.apply
|
||||
|
||||
# ----------------------------------------
|
||||
# train functions
|
||||
# ----------------------------------------
|
||||
|
||||
class _march_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
        ''' march rays to generate points (forward only)
        Args:
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            nears/fars: float, [N]
            perturb: bool, jitter the first marching step per ray
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
            contract: bool, forwarded to the CUDA kernel (scene contraction flag)
        Returns:
            xyzs: float, [M, 3], all generated points' coords. (all rays concated, need to use `rays` to extract points belonging to each ray)
            dirs: float, [M, 3], all generated points' view dirs.
            ts: float, [M, 2], all generated points' ts.
            rays: int32, [N, 2], all rays' (point_offset, point_count), e.g., xyzs[rays[i, 0]:(rays[i, 0] + rays[i, 1])] --> points belonging to rays[i, 0]
        '''

        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()
        if not density_bitfield.is_cuda: density_bitfield = density_bitfield.cuda()

        rays_o = rays_o.float().contiguous().view(-1, 3)
        rays_d = rays_d.float().contiguous().view(-1, 3)
        density_bitfield = density_bitfield.contiguous()

        N = rays_o.shape[0] # num rays

        # total-point counter, filled by the first kernel pass
        # NOTE(review): presumably incremented atomically inside the kernel — confirm there
        step_counter = torch.zeros(1, dtype=torch.int32, device=rays_o.device) # point counter, ray counter

        # per-ray jitter for the first step; zeros disable it (deterministic marching)
        if perturb:
            noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)

        # first pass: write rays, get total number of points M to render
        # (passing None for xyzs/dirs/ts tells the kernel to only count & record offsets)
        rays = torch.empty(N, 2, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps
        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, None, None, None, rays, step_counter, noises)

        # allocate based on M
        M = step_counter.item()  # .item() synchronizes with the device
        # print(M, N)
        # print(rays[:, 0].max())

        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device)

        # second pass: write outputs
        get_backend().march_rays_train(rays_o, rays_d, density_bitfield, bound, contract, dt_gamma, max_steps, N, C, H, nears, fars, xyzs, dirs, ts, rays, step_counter, noises)

        return xyzs, dirs, ts, rays

march_rays_train = _march_rays_train.apply
|
||||
|
||||
|
||||
class _composite_rays_train(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, sigmas, rgbs, ts, rays, T_thresh=1e-4, binarize=False):
        ''' composite rays' rgbs, according to the ray marching formula.
        Args:
            rgbs: float, [M, 3]
            sigmas: float, [M,]
            ts: float, [M, 2]
            rays: int32, [N, 2], per-ray (point_offset, point_count) as produced by march_rays_train
            T_thresh: float, stop compositing once transmittance drops below this
            binarize: bool, forwarded to the CUDA kernel
        Returns:
            weights: float, [M]
            weights_sum: float, [N,], the alpha channel
            depth: float, [N, ], the Depth
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        '''

        sigmas = sigmas.float().contiguous()
        rgbs = rgbs.float().contiguous()

        M = sigmas.shape[0]  # total number of sampled points
        N = rays.shape[0]    # number of rays

        weights = torch.zeros(M, dtype=sigmas.dtype, device=sigmas.device) # may leave unmodified, so init with 0
        weights_sum = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)

        depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
        image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)

        get_backend().composite_rays_train_forward(sigmas, rgbs, ts, rays, M, N, T_thresh, binarize, weights, weights_sum, depth, image)

        # stash forward results; backward re-walks the rays using them
        ctx.save_for_backward(sigmas, rgbs, ts, rays, weights_sum, depth, image)
        ctx.dims = [M, N, T_thresh, binarize]

        return weights, weights_sum, depth, image

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_weights, grad_weights_sum, grad_depth, grad_image):
        # Gradients w.r.t. sigmas and rgbs only; ts/rays/T_thresh/binarize get None.

        grad_weights = grad_weights.contiguous()
        grad_weights_sum = grad_weights_sum.contiguous()
        grad_depth = grad_depth.contiguous()
        grad_image = grad_image.contiguous()

        sigmas, rgbs, ts, rays, weights_sum, depth, image = ctx.saved_tensors
        M, N, T_thresh, binarize = ctx.dims

        grad_sigmas = torch.zeros_like(sigmas)
        grad_rgbs = torch.zeros_like(rgbs)

        get_backend().composite_rays_train_backward(grad_weights, grad_weights_sum, grad_depth, grad_image, sigmas, rgbs, ts, rays, weights_sum, depth, image, M, N, T_thresh, binarize, grad_sigmas, grad_rgbs)

        # one None per non-differentiable forward argument (ts, rays, T_thresh, binarize)
        return grad_sigmas, grad_rgbs, None, None, None, None


composite_rays_train = _composite_rays_train.apply
|
||||
|
||||
# ----------------------------------------
|
||||
# infer functions
|
||||
# ----------------------------------------
|
||||
|
||||
class _march_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, density_bitfield, C, H, near, far, perturb=False, dt_gamma=0, max_steps=1024, contract=False):
        ''' march rays to generate points (forward only, for inference)
        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
            rays_t: float, [N], the alive rays' time, we only use the first n_alive.
            rays_o/d: float, [N, 3]
            bound: float, scalar
            density_bitfield: uint8: [CHHH // 8]
            C: int
            H: int
            near/far: float, [N]
            perturb: bool/int, int > 0 is used as the random seed.
            dt_gamma: float, called cone_angle in instant-ngp, exponentially accelerate ray marching if > 0. (very significant effect, but generally lead to worse performance)
            max_steps: int, max number of sampled points along each ray, also affect min_stepsize.
            contract: bool, forwarded to the CUDA kernel (scene contraction flag)
        Returns:
            xyzs: float, [n_alive * n_step, 3], all generated points' coords
            dirs: float, [n_alive * n_step, 3], all generated points' view dirs.
            ts: float, [n_alive * n_step, 2], all generated points' ts
        '''

        if not rays_o.is_cuda: rays_o = rays_o.cuda()
        if not rays_d.is_cuda: rays_d = rays_d.cuda()

        rays_o = rays_o.float().contiguous().view(-1, 3)
        rays_d = rays_d.float().contiguous().view(-1, 3)

        # fixed-size output: every alive ray gets exactly n_step slots
        M = n_alive * n_step

        # zero-init: slots past a ray's termination stay zero
        xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
        ts = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth

        # per-ray jitter for the first step; zeros disable it
        if perturb:
            # torch.manual_seed(perturb) # test_gui uses spp index as seed
            noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
        else:
            noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)

        get_backend().march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, contract, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, ts, noises)

        return xyzs, dirs, ts

march_rays = _march_rays.apply
|
||||
|
||||
|
||||
class _composite_rays(Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)  # need to cast sigmas & rgbs to float
    def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image, T_thresh=1e-2, binarize=False):
        ''' Composite marched samples into the output buffers (inference only).

        Accumulates IN PLACE into weights_sum / depth / image; nothing is returned.

        Args:
            n_alive: int, number of alive rays
            n_step: int, how many steps we march
            rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
            rays_t: float, [N], the alive rays' time
            sigmas: float, [n_alive * n_step,]
            rgbs: float, [n_alive * n_step, 3]
            ts: float, [n_alive * n_step, 2]
        In-place Outputs:
            weights_sum: float, [N,], the alpha channel
            depth: float, [N,], the depth value
            image: float, [N, 3], the RGB channel (after multiplying alpha!)
        '''
        sigmas, rgbs = sigmas.float().contiguous(), rgbs.float().contiguous()
        get_backend().composite_rays(n_alive, n_step, T_thresh, binarize, rays_alive, rays_t, sigmas, rgbs, ts, weights_sum, depth, image)
        # all outputs are written in place, so there is nothing to return
        return tuple()


composite_rays = _composite_rays.apply
|
||||
63
raymarching/setup.py
Normal file
63
raymarching/setup.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# Directory containing this file; the CUDA sources live in ./src.
_src_path = os.path.dirname(os.path.abspath(__file__))

# Flags passed to nvcc; the -U flags re-enable half-precision operators
# that torch's build otherwise disables.
nvcc_flags = [
    '-O3', '-std=c++14',
    '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_HALF2_OPERATORS__',
]

if os.name == "posix":
    c_flags = ['-O3', '-std=c++14']
elif os.name == "nt":
    c_flags = ['/O2', '/std:c++17']

    # find cl.exe
    def find_cl_path():
        import glob
        for program_files in [r"C:\\Program Files (x86)", r"C:\\Program Files"]:
            for edition in ["Enterprise", "Professional", "BuildTools", "Community"]:
                paths = sorted(glob.glob(r"%s\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" % (program_files, edition)), reverse=True)
                if paths:
                    return paths[0]
        # implicitly returns None when no MSVC installation is found

    # If cl.exe is not on path, try to find it.
    if os.system("where cl.exe >nul 2>nul") != 0:
        cl_path = find_cl_path()
        if cl_path is None:
            raise RuntimeError("Could not locate a supported Microsoft Visual C++ installation")
        os.environ["PATH"] += ";" + cl_path

'''
Usage:

python setup.py build_ext --inplace # build extensions locally, do not install (only can be used from the parent directory)

python setup.py install # build extensions and install (copy) to PATH.
pip install . # ditto but better (e.g., dependency & metadata handling)

python setup.py develop # build extensions and install (symbolic) to PATH.
pip install -e . # ditto but better (e.g., dependency & metadata handling)

'''
setup(
    name='raymarching', # package name, import this to use python API
    ext_modules=[
        CUDAExtension(
            name='_raymarching', # extension name, import this to use CUDA API
            sources=[os.path.join(_src_path, 'src', f) for f in [
                'raymarching.cu',
                'bindings.cpp',
            ]],
            extra_compile_args={
                'cxx': c_flags,
                'nvcc': nvcc_flags,
            }
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension,
    }
)
|
||||
20
raymarching/src/bindings.cpp
Normal file
20
raymarching/src/bindings.cpp
Normal file
@@ -0,0 +1,20 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "raymarching.h"
|
||||
|
||||
// Python bindings: expose the host-side launcher functions declared in
// raymarching.h to Python via pybind11/torch.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // utils
    m.def("flatten_rays", &flatten_rays, "flatten_rays (CUDA)");
    m.def("packbits", &packbits, "packbits (CUDA)");
    m.def("near_far_from_aabb", &near_far_from_aabb, "near_far_from_aabb (CUDA)");
    m.def("sph_from_ray", &sph_from_ray, "sph_from_ray (CUDA)");
    m.def("morton3D", &morton3D, "morton3D (CUDA)");
    m.def("morton3D_invert", &morton3D_invert, "morton3D_invert (CUDA)");
    // train
    m.def("march_rays_train", &march_rays_train, "march_rays_train (CUDA)");
    m.def("composite_rays_train_forward", &composite_rays_train_forward, "composite_rays_train_forward (CUDA)");
    m.def("composite_rays_train_backward", &composite_rays_train_backward, "composite_rays_train_backward (CUDA)");
    // infer
    m.def("march_rays", &march_rays, "march rays (CUDA)");
    m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
}
|
||||
934
raymarching/src/raymarching.cu
Normal file
934
raymarching/src/raymarching.cu
Normal file
@@ -0,0 +1,934 @@
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <stdint.h>
|
||||
#include <stdexcept>
|
||||
#include <limits>
|
||||
|
||||
// Tensor-validation helpers for the host-side launchers.
#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor")
#define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor")
#define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor")


// Compile-time math constants: sqrt(3), 1/sqrt(3), pi, 1/pi.
inline constexpr __device__ float SQRT3() { return 1.7320508075688772f; }
inline constexpr __device__ float RSQRT3() { return 0.5773502691896258f; }
inline constexpr __device__ float PI() { return 3.141592653589793f; }
inline constexpr __device__ float RPI() { return 0.3183098861837907f; }
|
||||
|
||||
|
||||
// Integer ceiling division; used to size kernel launch grids.
template <typename T>
inline __host__ __device__ T div_round_up(T val, T divisor) {
    return (val + divisor - 1) / divisor;
}
|
||||
|
||||
// Sign of x as +/-1.0f (copies the sign bit, so signf(-0.0f) == -1.0f).
inline __host__ __device__ float signf(const float x) {
    // float literal avoids an implicit double constant in device code
    return copysignf(1.0f, x);
}
|
||||
|
||||
// Clamp x into [min, max].
inline __host__ __device__ float clamp(const float x, const float min, const float max) {
    return fminf(max, fmaxf(min, x));
}
|
||||
|
||||
// Swap two floats in place.
inline __host__ __device__ void swapf(float& a, float& b) {
    float c = a; a = b; b = c;
}
|
||||
|
||||
// Pick the cascade (mip) level for a position from its Chebyshev (max-abs)
// coordinate: level k covers |coord| in [2^(k-1), 2^k), clamped to
// [0, max_cascade - 1].
inline __device__ int mip_from_pos(const float x, const float y, const float z, const float max_cascade) {
    const float mx = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
    int exponent;
    frexpf(mx, &exponent); // [0, 0.5) --> -1, [0.5, 1) --> 0, [1, 2) --> 1, [2, 4) --> 2, ...
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}
|
||||
|
||||
// Pick the cascade (mip) level from the step size: a step spanning dt*H/2
// grid units maps to the level whose cell size matches it, clamped to
// [0, max_cascade - 1].
inline __device__ int mip_from_dt(const float dt, const float H, const float max_cascade) {
    const float mx = dt * H * 0.5;
    int exponent;
    frexpf(mx, &exponent);
    return fminf(max_cascade - 1, fmaxf(0, exponent));
}
|
||||
|
||||
// Spread the low 10 bits of v so each bit lands every third position
// (bit i -> bit 3i); building block for 3D Morton encoding.
inline __host__ __device__ uint32_t __expand_bits(uint32_t v)
{
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}
|
||||
|
||||
// Interleave the bits of (x, y, z) into a single Morton (Z-order) code:
// x occupies bits 0,3,6,... ; y bits 1,4,7,... ; z bits 2,5,8,...
inline __host__ __device__ uint32_t __morton3D(uint32_t x, uint32_t y, uint32_t z)
{
    uint32_t xx = __expand_bits(x);
    uint32_t yy = __expand_bits(y);
    uint32_t zz = __expand_bits(z);
    return xx | (yy << 1) | (zz << 2);
}
|
||||
|
||||
// Inverse of __expand_bits: collect every third bit of x back into the low
// bits. Shift the Morton code by 0/1/2 first to recover x/y/z respectively.
inline __host__ __device__ uint32_t __morton3D_invert(uint32_t x)
{
    x = x & 0x49249249;
    x = (x | (x >> 2)) & 0xc30c30c3;
    x = (x | (x >> 4)) & 0x0f00f00f;
    x = (x | (x >> 8)) & 0xff0000ff;
    x = (x | (x >> 16)) & 0x0000ffff;
    return x;
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////
|
||||
///////////// utils /////////////
|
||||
////////////////////////////////////////////////////
|
||||
|
||||
// rays_o/d: [N, 3]
// nears/fars: [N]
// scalar_t should always be float in use.
// Slab-method ray/AABB intersection: intersect against the x, y, z slab
// pairs in turn, shrinking [near, far]; a miss on any axis writes FLT_MAX
// to both outputs as a sentinel.
template <typename scalar_t>
__global__ void kernel_near_far_from_aabb(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const scalar_t * __restrict__ aabb,
    const uint32_t N,
    const float min_near,
    scalar_t * nears, scalar_t * fars
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate this thread's ray
    rays_o += n * 3;
    rays_d += n * 3;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    // reciprocals; a zero component yields +/-inf, which the slab test tolerates
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

    // get near far (assume cube scene)
    float near = (aabb[0] - ox) * rdx;
    float far = (aabb[3] - ox) * rdx;
    if (near > far) swapf(near, far);

    float near_y = (aabb[1] - oy) * rdy;
    float far_y = (aabb[4] - oy) * rdy;
    if (near_y > far_y) swapf(near_y, far_y);

    // x and y slab intervals disjoint --> the ray misses the box
    if (near > far_y || near_y > far) {
        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
        return;
    }

    if (near_y > near) near = near_y;
    if (far_y < far) far = far_y;

    float near_z = (aabb[2] - oz) * rdz;
    float far_z = (aabb[5] - oz) * rdz;
    if (near_z > far_z) swapf(near_z, far_z);

    // z slab disjoint from the running interval --> miss
    if (near > far_z || near_z > far) {
        nears[n] = fars[n] = std::numeric_limits<scalar_t>::max();
        return;
    }

    if (near_z > near) near = near_z;
    if (far_z < far) far = far_z;

    // never start marching behind/too close to the camera
    if (near < min_near) near = min_near;

    nears[n] = near;
    fars[n] = far;
}
|
||||
|
||||
|
||||
// Host launcher: one thread per ray, 128-thread blocks; dispatches on the
// floating dtype of rays_o.
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "near_far_from_aabb", ([&] {
        kernel_near_far_from_aabb<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), aabb.data_ptr<scalar_t>(), N, min_near, nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>());
    }));
}
|
||||
|
||||
|
||||
// rays_o/d: [N, 3]
// radius: float
// coords: [N, 2]
// Intersect each ray with Sphere(radius) and emit (theta, phi) of the hit
// point, both normalized to [-1, 1]. Assumes the origin is inside the sphere,
// so the larger quadratic root is the forward intersection.
template <typename scalar_t>
__global__ void kernel_sph_from_ray(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const float radius,
    const uint32_t N,
    scalar_t * coords
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate this thread's ray and output slot
    rays_o += n * 3;
    rays_d += n * 3;
    coords += n * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    // const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;

    // solve t from || o + td || = radius  (quadratic A t^2 + 2B t + C = 0)
    const float A = dx * dx + dy * dy + dz * dz;
    const float B = ox * dx + oy * dy + oz * dz; // in fact B / 2
    const float C = ox * ox + oy * oy + oz * oz - radius * radius;

    const float t = (- B + sqrtf(B * B - A * C)) / A; // always use the larger solution (positive)

    // solve theta, phi (assume y is the up axis)
    const float x = ox + t * dx, y = oy + t * dy, z = oz + t * dz;
    const float theta = atan2(sqrtf(x * x + z * z), y); // [0, PI)
    const float phi = atan2(z, x); // [-PI, PI)

    // normalize to [-1, 1]
    coords[0] = 2 * theta * RPI() - 1;
    coords[1] = phi * RPI();
}
|
||||
|
||||
|
||||
// Host launcher: one thread per ray, 128-thread blocks; dispatches on the
// floating dtype of rays_o.
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "sph_from_ray", ([&] {
        kernel_sph_from_ray<<<div_round_up(N, N_THREAD), N_THREAD>>>(rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), radius, N, coords.data_ptr<scalar_t>());
    }));
}
|
||||
|
||||
|
||||
// coords: int32, [N, 3]
// indices: int32, [N]
// One thread per coordinate triple; encodes it as a Morton index.
__global__ void kernel_morton3D(
    const int * __restrict__ coords,
    const uint32_t N,
    int * indices
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate this thread's coordinate triple
    coords += n * 3;
    indices[n] = __morton3D(coords[0], coords[1], coords[2]);
}
|
||||
|
||||
|
||||
// Host launcher: one thread per coordinate triple, 128-thread blocks.
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D<<<div_round_up(N, N_THREAD), N_THREAD>>>(coords.data_ptr<int>(), N, indices.data_ptr<int>());
}
|
||||
|
||||
|
||||
// indices: int32, [N]
// coords: int32, [N, 3]
// One thread per Morton index; decodes it back into (x, y, z).
__global__ void kernel_morton3D_invert(
    const int * __restrict__ indices,
    const uint32_t N,
    int * coords
) {
    // parallel
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate this thread's output triple
    coords += n * 3;

    const int ind = indices[n];

    // x/y/z live at bit offsets 0/1/2 of the interleaved code
    coords[0] = __morton3D_invert(ind >> 0);
    coords[1] = __morton3D_invert(ind >> 1);
    coords[2] = __morton3D_invert(ind >> 2);
}
|
||||
|
||||
|
||||
// Host launcher: one thread per Morton index, 128-thread blocks.
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords) {
    static constexpr uint32_t N_THREAD = 128;
    kernel_morton3D_invert<<<div_round_up(N, N_THREAD), N_THREAD>>>(indices.data_ptr<int>(), N, coords.data_ptr<int>());
}
|
||||
|
||||
|
||||
// grid: float, [C, H, H, H]
// N: int, C * H * H * H / 8
// density_thresh: float
// bitfield: uint8, [N]
// One thread per output byte: compare 8 consecutive grid cells against the
// threshold and pack the results into one bit each.
template <typename scalar_t>
__global__ void kernel_packbits(
    const scalar_t * __restrict__ grid,
    const uint32_t N,
    const float density_thresh,
    uint8_t * bitfield
) {
    // parallel per byte
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate the 8 cells this byte covers
    grid += n * 8;

    uint8_t bits = 0;

    #pragma unroll
    for (uint8_t i = 0; i < 8; i++) {
        bits |= (grid[i] > density_thresh) ? ((uint8_t)1 << i) : 0;
    }

    bitfield[n] = bits;
}
|
||||
|
||||
|
||||
// Host launcher: one thread per output byte, 128-thread blocks; dispatches
// on the floating dtype of grid.
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield) {

    static constexpr uint32_t N_THREAD = 128;

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grid.scalar_type(), "packbits", ([&] {
        kernel_packbits<<<div_round_up(N, N_THREAD), N_THREAD>>>(grid.data_ptr<scalar_t>(), N, density_thresh, bitfield.data_ptr<uint8_t>());
    }));
}
|
||||
|
||||
|
||||
// rays: int32, [N, 2], per-ray (point_offset, point_count)
// res: int32, [M], res[j] = index of the ray that produced point j
// One thread per ray: fill that ray's segment of the flat point buffer
// with the ray's own index.
__global__ void kernel_flatten_rays(
    const int * __restrict__ rays,
    const uint32_t N, const uint32_t M,
    int * res
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate this ray's segment
    uint32_t offset = rays[n * 2];
    uint32_t num_steps = rays[n * 2 + 1];

    // write to res
    res += offset;
    // uint32_t index: avoids the signed/unsigned comparison against num_steps
    for (uint32_t i = 0; i < num_steps; i++) res[i] = n;
}
|
||||
|
||||
// Host entry for kernel_flatten_rays: one thread per ray.
void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res) {

    static constexpr uint32_t N_THREAD = 128;
    const uint32_t n_blocks = div_round_up(N, N_THREAD);

    kernel_flatten_rays<<<n_blocks, N_THREAD>>>(rays.data_ptr<int>(), N, M, res.data_ptr<int>());
}
|
||||
|
||||
////////////////////////////////////////////////////
|
||||
///////////// training /////////////
|
||||
////////////////////////////////////////////////////
|
||||
|
||||
// One thread per ray: march from near to far through the occupancy bitfield.
// Runs in two passes. First pass (xyzs == nullptr): only count the number of
// occupied samples per ray and atomically reserve an output slice. Second
// pass: replay the identical march and write samples into the reserved slice.
//
// rays_o/d: [N, 3]
// grid: [CHHH / 8] packed occupancy bitfield, one bit per cell
// xyzs, dirs, ts: [M, 3], [M, 3], [M, 2]
// dirs: [M, 3]
// rays: [N, 2], offset, num_steps  (comment fixed: code indexes rays as [N, 2])
template <typename scalar_t>
__global__ void kernel_march_rays_train(
    const scalar_t * __restrict__ rays_o,
    const scalar_t * __restrict__ rays_d,
    const uint8_t * __restrict__ grid,
    const float bound, const bool contract,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t N, const uint32_t C, const uint32_t H,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t * xyzs, scalar_t * dirs, scalar_t * ts,
    int * rays,
    int * counter,
    const scalar_t* __restrict__ noises
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // is first pass running (output buffers not yet allocated).
    const bool first_pass = (xyzs == nullptr);

    // locate this ray's inputs / per-ray record
    rays_o += n * 3;
    rays_d += n * 3;
    rays += n * 2;

    uint32_t num_steps = max_steps;

    // second pass: read back (offset, num_steps) produced by the first pass
    // and advance the output pointers to this ray's reserved slice.
    if (!first_pass) {
        uint32_t point_index = rays[0];
        num_steps = rays[1];
        xyzs += point_index * 3;
        dirs += point_index * 3;
        ts += point_index * 2;
    }

    // ray marching
    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; // reciprocals for voxel skipping
    const float rH = 1 / (float)H;
    const float H3 = H * H * H; // cells per cascade level

    const float near = nears[n];
    const float far = fars[n];
    const float noise = noises[n]; // per-ray jitter factor for the start offset

    // step-size bounds: dt grows with t (dt = t * dt_gamma), clamped here.
    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * bound / H;
    // const float dt_max = 1e10f;

    // jittered start distance
    float t0 = near;
    t0 += clamp(t0 * dt_gamma, dt_min, dt_max) * noise;
    float t = t0;
    uint32_t step = 0;

    //if (t < far) printf("valid ray %d t=%f near=%f far=%f \n", n, t, near, far);

    while (t < far && step < num_steps) {
        // current point, clamped into the scene bound
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        const float mip_bound = fminf(scalbnf(1.0f, level), bound); // 2^level, capped at bound
        const float mip_rbound = 1 / mip_bound;

        // contraction: points with L-inf norm > 1 are squashed toward [-2, 2]
        float cx = x, cy = y, cz = z;
        const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
        if (contract && mag > 1) {
            // L-INF norm
            const float Linf_scale = (2 - 1 / mag) / mag;
            cx *= Linf_scale;
            cy *= Linf_scale;
            cz *= Linf_scale;
        }

        // convert to nearest grid position
        const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        // query the packed occupancy bitfield (Morton-ordered within a level)
        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        //if (n == 0) printf("t=%f density=%f vs thresh=%f step=%d\n", t, density, density_thresh, step);

        if (occ) {
            step++;
            t += dt;
            if (!first_pass) {
                xyzs[0] = cx; // write contracted coordinates!
                xyzs[1] = cy;
                xyzs[2] = cz;
                dirs[0] = dx;
                dirs[1] = dy;
                dirs[2] = dz;
                ts[0] = t;
                ts[1] = dt;
                xyzs += 3;
                dirs += 3;
                ts += 2;
            }
        // contraction case: cannot apply voxel skipping.
        } else if (contract && mag > 1) {
            t += dt;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to the boundary of the next voxel along each axis
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;

            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel, keeping the same dt schedule as above
            do {
                dt = clamp(t * dt_gamma, dt_min, dt_max);
                t += dt;
            } while (t < tt);
        }
    }

    //printf("[n=%d] step=%d, near=%f, far=%f, dt=%f, num_steps=%f\n", n, step, near, far, dt_min, (far - near) / dt_min);

    // first pass: atomically reserve [point_index, point_index + step) in the
    // global sample buffers and record this ray's slice.
    if (first_pass) {
        uint32_t point_index = atomicAdd(counter, step);
        rays[0] = point_index;
        rays[1] = step;
    }
}
|
||||
|
||||
// Host entry for kernel_march_rays_train. xyzs/dirs/ts are optional:
// absent on the first (counting) pass, present on the second (writing) pass.
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises) {

    static constexpr uint32_t N_THREAD = 128;
    const uint32_t n_blocks = div_round_up(N, N_THREAD);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays_train", ([&] {
        // resolve the optional output buffers up front (nullptr => first pass)
        scalar_t* xyzs_ptr = xyzs.has_value() ? xyzs.value().data_ptr<scalar_t>() : nullptr;
        scalar_t* dirs_ptr = dirs.has_value() ? dirs.value().data_ptr<scalar_t>() : nullptr;
        scalar_t* ts_ptr = ts.has_value() ? ts.value().data_ptr<scalar_t>() : nullptr;
        kernel_march_rays_train<<<n_blocks, N_THREAD>>>(
            rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(), grid.data_ptr<uint8_t>(),
            bound, contract, dt_gamma, max_steps, N, C, H,
            nears.data_ptr<scalar_t>(), fars.data_ptr<scalar_t>(),
            xyzs_ptr, dirs_ptr, ts_ptr,
            rays.data_ptr<int>(), counter.data_ptr<int>(), noises.data_ptr<scalar_t>());
    }));
}
|
||||
|
||||
|
||||
// Volume-render each training ray: front-to-back alpha compositing of the
// per-sample densities and colors into a pixel color, depth and alpha.
//
// sigmas: [M]
// rgbs: [M, 3]
// ts: [M, 2]  (t, dt) per sample
// rays: [N, 2], offset, num_steps
// weights: [M]  per-sample compositing weight (saved for backward / losses)
// weights_sum: [N], final pixel alpha
// depth: [N,]
// image: [N, 3]
template <typename scalar_t>
__global__ void kernel_composite_rays_train_forward(
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ ts,
    const int * __restrict__ rays,
    const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,
    scalar_t * weights,
    scalar_t * weights_sum,
    scalar_t * depth,
    scalar_t * image
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    uint32_t offset = rays[n * 2];
    uint32_t num_steps = rays[n * 2 + 1];

    // empty ray, or ray that exceed max step count: output all-zero pixel.
    if (num_steps == 0 || offset + num_steps > M) {
        weights_sum[n] = 0;
        depth[n] = 0;
        image[n * 3] = 0;
        image[n * 3 + 1] = 0;
        image[n * 3 + 2] = 0;
        return;
    }

    // advance to this ray's sample slice
    ts += offset * 2;
    weights += offset;
    sigmas += offset;
    rgbs += offset * 3;

    // accumulate
    uint32_t step = 0;

    float T = 1.0f; // transmittance so far
    float r = 0, g = 0, b = 0, ws = 0, d = 0;

    while (step < num_steps) {

        // alpha_i = 1 - exp(-sigma_i * dt_i); optionally hard-thresholded.
        const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
        const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
        const float weight = alpha * T;

        weights[0] = weight; // stored per-sample for the backward pass

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];
        ws += weight;
        d += weight * ts[0];

        T *= 1.0f - alpha;

        // minimal remaining transmittance: terminate once the ray saturates
        if (T < T_thresh) break;

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // locate next sample
        weights++;
        sigmas++;
        rgbs += 3;
        ts += 2;

        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // write per-ray outputs
    weights_sum[n] = ws; // weights_sum
    depth[n] = d;
    image[n * 3] = r;
    image[n * 3 + 1] = g;
    image[n * 3 + 2] = b;
}
|
||||
|
||||
|
||||
// Host entry for kernel_composite_rays_train_forward: one thread per ray.
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image) {

    static constexpr uint32_t N_THREAD = 128;
    const uint32_t n_blocks = div_round_up(N, N_THREAD);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    sigmas.scalar_type(), "composite_rays_train_forward", ([&] {
        kernel_composite_rays_train_forward<<<n_blocks, N_THREAD>>>(
            sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(),
            rays.data_ptr<int>(), M, N, T_thresh, binarize,
            weights.data_ptr<scalar_t>(), weights_sum.data_ptr<scalar_t>(),
            depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}
|
||||
|
||||
|
||||
// Backward pass of kernel_composite_rays_train_forward: replays the forward
// compositing loop and writes analytic gradients w.r.t. each sample's sigma
// and rgb. The suffix sums needed by the sigma gradient are recovered as
// (forward total) - (running prefix).
//
// grad_weights: [M,]
// grad_weights_sum: [N,]
// grad_image: [N, 3]
// grad_depth: [N,]
// sigmas: [M]
// rgbs: [M, 3]
// ts: [M, 2]
// rays: [N, 2], offset, num_steps
// weights_sum: [N,], weights_sum here
// image: [N, 3]
// grad_sigmas: [M]
// grad_rgbs: [M, 3]
template <typename scalar_t>
__global__ void kernel_composite_rays_train_backward(
    const scalar_t * __restrict__ grad_weights,
    const scalar_t * __restrict__ grad_weights_sum,
    const scalar_t * __restrict__ grad_depth,
    const scalar_t * __restrict__ grad_image,
    const scalar_t * __restrict__ sigmas,
    const scalar_t * __restrict__ rgbs,
    const scalar_t * __restrict__ ts,
    const int * __restrict__ rays,
    const scalar_t * __restrict__ weights_sum,
    const scalar_t * __restrict__ depth,
    const scalar_t * __restrict__ image,
    const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize,
    scalar_t * grad_sigmas,
    scalar_t * grad_rgbs
) {
    // parallel per ray
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= N) return;

    // locate
    uint32_t offset = rays[n * 2];
    uint32_t num_steps = rays[n * 2 + 1];

    // same guard as the forward pass: nothing to backprop for these rays.
    if (num_steps == 0 || offset + num_steps > M) return;

    // per-sample buffers indexed by offset, per-ray buffers by n
    grad_weights += offset;
    grad_weights_sum += n;
    grad_depth += n;
    grad_image += n * 3;
    weights_sum += n;
    depth += n;
    image += n * 3;
    sigmas += offset;
    rgbs += offset * 3;
    ts += offset * 2;
    grad_sigmas += offset;
    grad_rgbs += offset * 3;

    // accumulate
    uint32_t step = 0;

    float T = 1.0f;
    // forward totals for this ray (used to form suffix sums)
    const float r_final = image[0], g_final = image[1], b_final = image[2], ws_final = weights_sum[0], d_final = depth[0];
    float r = 0, g = 0, b = 0, ws = 0, d = 0;

    while (step < num_steps) {

        // replay the forward computation for this sample
        const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
        const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;
        const float weight = alpha * T;

        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];
        ws += weight;
        d += weight * ts[0];

        T *= 1.0f - alpha;

        // check https://note.kiui.moe/others/nerf_gradient/ for the gradient calculation.
        // write grad_rgbs: d(image)/d(rgb_i) is just this sample's weight
        grad_rgbs[0] = grad_image[0] * weight;
        grad_rgbs[1] = grad_image[1] * weight;
        grad_rgbs[2] = grad_image[2] * weight;

        // write grad_sigmas: each term is T_i * (this sample's contribution)
        // minus the suffix contribution of the later samples, scaled by dt_i.
        grad_sigmas[0] = ts[1] * (
            grad_image[0] * (T * rgbs[0] - (r_final - r)) +
            grad_image[1] * (T * rgbs[1] - (g_final - g)) +
            grad_image[2] * (T * rgbs[2] - (b_final - b)) +
            (grad_weights_sum[0] + grad_weights[0]) * (T - (ws_final - ws)) +
            grad_depth[0] * (T * ts[0] - (d_final - d))
        );

        //printf("[n=%d] num_steps=%d, T=%f, grad_sigmas=%f, r_final=%f, r=%f\n", n, step, T, grad_sigmas[0], r_final, r);
        // minimal remaining transmittance: must match the forward early exit
        if (T < T_thresh) break;

        // locate next sample
        sigmas++;
        rgbs += 3;
        ts += 2;
        grad_weights++;
        grad_sigmas++;
        grad_rgbs += 3;

        step++;
    }
}
|
||||
|
||||
|
||||
// Host entry for kernel_composite_rays_train_backward: one thread per ray.
void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs) {

    static constexpr uint32_t N_THREAD = 128;
    const uint32_t n_blocks = div_round_up(N, N_THREAD);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad_image.scalar_type(), "composite_rays_train_backward", ([&] {
        kernel_composite_rays_train_backward<<<n_blocks, N_THREAD>>>(
            grad_weights.data_ptr<scalar_t>(), grad_weights_sum.data_ptr<scalar_t>(),
            grad_depth.data_ptr<scalar_t>(), grad_image.data_ptr<scalar_t>(),
            sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(),
            rays.data_ptr<int>(), weights_sum.data_ptr<scalar_t>(),
            depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>(),
            M, N, T_thresh, binarize,
            grad_sigmas.data_ptr<scalar_t>(), grad_rgbs.data_ptr<scalar_t>());
    }));
}
|
||||
|
||||
|
||||
////////////////////////////////////////////////////
|
||||
///////////// inference /////////////
|
||||
////////////////////////////////////////////////////
|
||||
|
||||
// Inference-time marching: for each still-alive ray, march up to n_step
// samples forward from its current distance rays_t[ray_id] and record them.
// Per-ray state (origin, direction, t, near/far) is indexed by the ray id;
// the sample outputs are indexed by the compacted alive-slot n.
template <typename scalar_t>
__global__ void kernel_march_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const int* __restrict__ rays_alive,
    const scalar_t* __restrict__ rays_t,
    const scalar_t* __restrict__ rays_o,
    const scalar_t* __restrict__ rays_d,
    const float bound, const bool contract,
    const float dt_gamma, const uint32_t max_steps,
    const uint32_t C, const uint32_t H,
    const uint8_t * __restrict__ grid,
    const scalar_t* __restrict__ nears,
    const scalar_t* __restrict__ fars,
    scalar_t* xyzs, scalar_t* dirs, scalar_t* ts,
    const scalar_t* __restrict__ noises
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id
    const float noise = noises[n];

    // locate: ray state by ray id, outputs by alive slot n
    rays_o += index * 3;
    rays_d += index * 3;
    xyzs += n * n_step * 3;
    dirs += n * n_step * 3;
    ts += n * n_step * 2;

    const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
    const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
    const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz; // reciprocals for voxel skipping
    const float rH = 1 / (float)H;
    const float H3 = H * H * H; // cells per cascade level

    const float near = nears[index], far = fars[index];

    // same dt schedule as kernel_march_rays_train
    const float dt_min = 2 * SQRT3() / max_steps;
    const float dt_max = 2 * SQRT3() * bound / H;
    // const float dt_max = 1e10f;

    // march for n_step steps, record points
    float t = rays_t[index];
    t += clamp(t * dt_gamma, dt_min, dt_max) * noise; // per-ray jitter
    uint32_t step = 0;

    while (t < far && step < n_step) {
        // current point, clamped into the scene bound
        const float x = clamp(ox + t * dx, -bound, bound);
        const float y = clamp(oy + t * dy, -bound, bound);
        const float z = clamp(oz + t * dz, -bound, bound);

        float dt = clamp(t * dt_gamma, dt_min, dt_max);

        // get mip level
        const int level = max(mip_from_pos(x, y, z, C), mip_from_dt(dt, H, C)); // range in [0, C - 1]

        // NOTE(review): the training kernel calls scalbnf(1.0f, level); the
        // int literal here converts implicitly, same result — worth unifying.
        const float mip_bound = fminf(scalbnf(1, level), bound);
        const float mip_rbound = 1 / mip_bound;

        // contraction: points with L-inf norm > 1 are squashed toward [-2, 2]
        float cx = x, cy = y, cz = z;
        const float mag = fmaxf(fabsf(x), fmaxf(fabsf(y), fabsf(z)));
        if (contract && mag > 1) {
            // L-INF norm
            const float Linf_scale = (2 - 1 / mag) / mag;
            cx *= Linf_scale;
            cy *= Linf_scale;
            cz *= Linf_scale;
        }

        // convert to nearest grid position
        const int nx = clamp(0.5 * (cx * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int ny = clamp(0.5 * (cy * mip_rbound + 1) * H, 0.0f, (float)(H - 1));
        const int nz = clamp(0.5 * (cz * mip_rbound + 1) * H, 0.0f, (float)(H - 1));

        // NOTE(review): this local `index` shadows the outer ray id `index`;
        // intentional within this scope, but a rename would aid readability.
        const uint32_t index = level * H3 + __morton3D(nx, ny, nz);
        const bool occ = grid[index / 8] & (1 << (index % 8));

        // if occupied, advance a small step, and write to output
        if (occ) {
            // write step
            xyzs[0] = cx;
            xyzs[1] = cy;
            xyzs[2] = cz;
            dirs[0] = dx;
            dirs[1] = dy;
            dirs[2] = dz;
            // calc dt
            t += dt;
            ts[0] = t;
            ts[1] = dt;
            // step
            xyzs += 3;
            dirs += 3;
            ts += 2;
            step++;

        // contraction case: cannot apply voxel skipping.
        } else if (contract && mag > 1) {
            t += dt;
        // else, skip a large step (basically skip a voxel grid)
        } else {
            // calc distance to the boundary of the next voxel along each axis
            const float tx = (((nx + 0.5f + 0.5f * signf(dx)) * rH * 2 - 1) * mip_bound - cx) * rdx;
            const float ty = (((ny + 0.5f + 0.5f * signf(dy)) * rH * 2 - 1) * mip_bound - cy) * rdy;
            const float tz = (((nz + 0.5f + 0.5f * signf(dz)) * rH * 2 - 1) * mip_bound - cz) * rdz;
            const float tt = t + fmaxf(0.0f, fminf(tx, fminf(ty, tz)));
            // step until next voxel
            do {
                dt = clamp(t * dt_gamma, dt_min, dt_max);
                t += dt;
            } while (t < tt);
        }
    }
}
|
||||
|
||||
|
||||
// Host entry for kernel_march_rays: one thread per alive ray.
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor near, const at::Tensor far, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises) {

    static constexpr uint32_t N_THREAD = 128;
    const uint32_t n_blocks = div_round_up(n_alive, N_THREAD);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    rays_o.scalar_type(), "march_rays", ([&] {
        kernel_march_rays<<<n_blocks, N_THREAD>>>(
            n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(),
            rays_o.data_ptr<scalar_t>(), rays_d.data_ptr<scalar_t>(),
            bound, contract, dt_gamma, max_steps, C, H,
            grid.data_ptr<uint8_t>(), near.data_ptr<scalar_t>(), far.data_ptr<scalar_t>(),
            xyzs.data_ptr<scalar_t>(), dirs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(),
            noises.data_ptr<scalar_t>());
    }));
}
|
||||
|
||||
|
||||
// Inference-time compositing: fold this round's n_step samples of each alive
// ray into its accumulated pixel (weights_sum, depth, image), advancing the
// ray's t, and marking it dead (rays_alive[n] = -1) on early termination.
template <typename scalar_t>
__global__ void kernel_composite_rays(
    const uint32_t n_alive,
    const uint32_t n_step,
    const float T_thresh, const bool binarize,
    int* rays_alive,
    scalar_t* rays_t,
    const scalar_t* __restrict__ sigmas,
    const scalar_t* __restrict__ rgbs,
    const scalar_t* __restrict__ ts,
    scalar_t* weights_sum, scalar_t* depth, scalar_t* image
) {
    const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
    if (n >= n_alive) return;

    const int index = rays_alive[n]; // ray id

    // locate: samples by alive slot n, accumulators by ray id
    sigmas += n * n_step;
    rgbs += n * n_step * 3;
    ts += n * n_step * 2;

    rays_t += index;
    weights_sum += index;
    depth += index;
    image += index * 3;

    // NOTE(review): t is only assigned inside the loop; it is read back only
    // when step == n_step, which implies the loop body ran — assumes
    // n_step > 0; confirm against callers.
    float t;
    // resume from the accumulators written by previous rounds
    float d = depth[0], r = image[0], g = image[1], b = image[2], weight_sum = weights_sum[0];

    // accumulate
    uint32_t step = 0;
    while (step < n_step) {

        // ray is terminated if t == 0 (unwritten sample slot)
        if (ts[0] == 0) break;

        const float real_alpha = 1.0f - __expf(- sigmas[0] * ts[1]);
        const float alpha = binarize ? (real_alpha > 0.5 ? 1.0 : 0.0) : real_alpha;

        /*
        T_0 = 1; T_i = \prod_{j=0}^{i-1} (1 - alpha_j)
        w_i = alpha_i * T_i
        -->
        T_i = 1 - \sum_{j=0}^{i-1} w_j
        */
        // transmittance recovered from the running weight sum (see identity above)
        const float T = 1 - weight_sum;
        const float weight = alpha * T;
        weight_sum += weight;

        t = ts[0];
        d += weight * t; // real depth
        r += weight * rgbs[0];
        g += weight * rgbs[1];
        b += weight * rgbs[2];

        //printf("[n=%d] num_steps=%d, alpha=%f, w=%f, T=%f, sum_dt=%f, d=%f\n", n, step, alpha, weight, T, sum_delta, d);

        // ray is terminated if T is too small
        // use a larger bound to further accelerate inference
        if (T < T_thresh) break;

        // locate next sample
        sigmas++;
        rgbs += 3;
        ts += 2;
        step++;
    }

    //printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

    // rays_alive = -1 means ray is terminated early.
    if (step < n_step) {
        rays_alive[n] = -1;
    } else {
        // ray survives this round: advance its marching distance
        rays_t[0] = t;
    }

    // write back the accumulators for the next round
    weights_sum[0] = weight_sum;
    depth[0] = d;
    image[0] = r;
    image[1] = g;
    image[2] = b;
}
|
||||
|
||||
|
||||
// Host entry for kernel_composite_rays: one thread per alive ray.
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights, at::Tensor depth, at::Tensor image) {

    static constexpr uint32_t N_THREAD = 128;
    const uint32_t n_blocks = div_round_up(n_alive, N_THREAD);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    image.scalar_type(), "composite_rays", ([&] {
        kernel_composite_rays<<<n_blocks, N_THREAD>>>(
            n_alive, n_step, T_thresh, binarize,
            rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(),
            sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), ts.data_ptr<scalar_t>(),
            weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
    }));
}
|
||||
19
raymarching/src/raymarching.h
Normal file
19
raymarching/src/raymarching.h
Normal file
@@ -0,0 +1,19 @@
|
||||
#pragma once

#include <stdint.h>
#include <torch/torch.h>


// ---- utility kernels: ray/AABB and ray/sphere intersection, Morton
// ---- encode/decode, occupancy-grid bit packing, ray flattening ----
void near_far_from_aabb(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor aabb, const uint32_t N, const float min_near, at::Tensor nears, at::Tensor fars);
void sph_from_ray(const at::Tensor rays_o, const at::Tensor rays_d, const float radius, const uint32_t N, at::Tensor coords);
void morton3D(const at::Tensor coords, const uint32_t N, at::Tensor indices);
void morton3D_invert(const at::Tensor indices, const uint32_t N, at::Tensor coords);
void packbits(const at::Tensor grid, const uint32_t N, const float density_thresh, at::Tensor bitfield);
void flatten_rays(const at::Tensor rays, const uint32_t N, const uint32_t M, at::Tensor res);

// ---- training: two-pass ray marching and differentiable compositing ----
void march_rays_train(const at::Tensor rays_o, const at::Tensor rays_d, const at::Tensor grid, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t N, const uint32_t C, const uint32_t H, const at::Tensor nears, const at::Tensor fars, at::optional<at::Tensor> xyzs, at::optional<at::Tensor> dirs, at::optional<at::Tensor> ts, at::Tensor rays, at::Tensor counter, at::Tensor noises);
void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor weights, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void composite_rays_train_backward(const at::Tensor grad_weights, const at::Tensor grad_weights_sum, const at::Tensor grad_depth, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor ts, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor depth, const at::Tensor image, const uint32_t M, const uint32_t N, const float T_thresh, const bool binarize, at::Tensor grad_sigmas, at::Tensor grad_rgbs);

// ---- inference: incremental marching / compositing of alive rays ----
void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const bool contract, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor ts, at::Tensor noises);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const float T_thresh, const bool binarize, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor ts, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
|
||||
Reference in New Issue
Block a user