69 lines
2.2 KiB
Python
69 lines
2.2 KiB
Python
import taichi as ti
|
|
import torch
|
|
from taichi.math import vec3
|
|
from torch.cuda.amp import custom_fwd
|
|
|
|
from .utils import NEAR_DISTANCE
|
|
|
|
|
|
# Slab-method intersection of every ray with ONE axis-aligned box.
#
# Only row 0 of ``centers`` / ``half_sizes`` is read, so at most one
# (t_near, t_far) interval is produced per ray.
#
# Args:
#   hits_t:     (N_rays, 1, 2) float output, updated in place. For a ray whose
#               slab intervals overlap and whose exit lies in front of the
#               origin, [r, 0, 0] receives max(t_near, NEAR_DISTANCE) and
#               [r, 0, 1] receives t_far. All other rows are left untouched
#               (the caller pre-fills them with -1 as a "no hit" sentinel).
#   rays_o:     (N_rays, 3) ray origins.
#   rays_d:     (N_rays, 3) ray directions; t is measured in multiples of
#               rays_d, so it is only a metric distance if rays_d is unit.
#   centers:    (>=1, 3) box centers; only row 0 is used.
#   half_sizes: (>=1, 3) box half extents; only row 0 is used.
@ti.kernel
def simple_ray_aabb_intersec_taichi_forward(
        hits_t: ti.types.ndarray(ndim=3),
        rays_o: ti.types.ndarray(ndim=2),
        rays_d: ti.types.ndarray(ndim=2),
        centers: ti.types.ndarray(ndim=2),
        half_sizes: ti.types.ndarray(ndim=2)):
    for r in ti.ndrange(hits_t.shape[0]):
        ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]])
        ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]])
        # Per-axis reciprocal; a zero direction component yields +/-inf.
        inv_d = 1.0 / ray_d

        center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]])
        # x, y, z half extents come from columns 0, 1, 2 respectively.
        half_size = vec3(
            [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 2]])

        # Ray parameters at which it crosses the lower / upper plane of
        # each axis-aligned slab.
        t_min = (center - half_size - ray_o) * inv_d
        t_max = (center + half_size - ray_o) * inv_d

        # Per-axis entry / exit, independent of the sign of the direction.
        _t1 = ti.min(t_min, t_max)
        _t2 = ti.max(t_min, t_max)
        # The ray is inside the box after entering ALL slabs and before
        # leaving ANY of them.
        t1 = _t1.max()
        t2 = _t2.min()

        # Hit only if the slab intervals overlap (t1 <= t2) and the exit is
        # in front of the origin; never start closer than NEAR_DISTANCE
        # (this also covers origins inside the box, where t1 < 0).
        if t1 <= t2 and t2 > 0.0:
            hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE)
            hits_t[r, 0, 1] = t2
|
|
|
|
|
|
class RayAABBIntersector(torch.autograd.Function):
    """
    Computes the intersection of rays with a single axis-aligned box.

    Only the first box (row 0 of ``center`` / ``half_size``) is tested, and
    exactly one hit slot is produced per ray.

    Inputs:
        rays_o: (N_rays, 3) ray origins
        rays_d: (N_rays, 3) ray directions
        center: (N_voxels, 3) box centers; only row 0 is used
        half_size: (N_voxels, 3) box half sizes; only row 0 is used
        max_hits: accepted for interface compatibility but currently ignored
            (the output always has a single hit slot per ray)

    Outputs (a 3-tuple):
        None: placeholder where a per-ray hit count would go (not computed)
        hits_t: (N_rays, 1, 2) float32 (t_near, t_far) of the box hit, with
            t_near clamped to at least NEAR_DISTANCE; entries the kernel does
            not write for a ray stay at -1 (no hit)
        None: placeholder where hit voxel indices would go (not computed)
    """

    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
        """Run the Taichi intersection kernel; see the class docstring."""
        # One (t_near, t_far) slot per ray, pre-filled with the -1 "no hit"
        # sentinel so rays the kernel skips are recognisable.
        hits_t = torch.full((rays_o.size(0), 1, 2), -1.0,
                            device=rays_o.device, dtype=torch.float32)

        # Taichi reads torch tensors as raw buffers and rejects
        # non-contiguous ones; .contiguous() is a no-op when already packed.
        simple_ray_aabb_intersec_taichi_forward(
            hits_t,
            rays_o.contiguous(),
            rays_d.contiguous(),
            center.contiguous(),
            half_size.contiguous(),
        )

        return None, hits_t, None
|