import taichi as ti import torch from taichi.math import vec3 from torch.cuda.amp import custom_fwd from .utils import NEAR_DISTANCE @ti.kernel def simple_ray_aabb_intersec_taichi_forward( hits_t: ti.types.ndarray(ndim=2), rays_o: ti.types.ndarray(ndim=2), rays_d: ti.types.ndarray(ndim=2), centers: ti.types.ndarray(ndim=2), half_sizes: ti.types.ndarray(ndim=2)): for r in ti.ndrange(hits_t.shape[0]): ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]]) ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]]) inv_d = 1.0 / ray_d center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]]) half_size = vec3( [half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 1]]) t_min = (center - half_size - ray_o) * inv_d t_max = (center + half_size - ray_o) * inv_d _t1 = ti.min(t_min, t_max) _t2 = ti.max(t_min, t_max) t1 = _t1.max() t2 = _t2.min() if t2 > 0.0: hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE) hits_t[r, 0, 1] = t2 class RayAABBIntersector(torch.autograd.Function): """ Computes the intersections of rays and axis-aligned voxels. Inputs: rays_o: (N_rays, 3) ray origins rays_d: (N_rays, 3) ray directions centers: (N_voxels, 3) voxel centers half_sizes: (N_voxels, 3) voxel half sizes max_hits: maximum number of intersected voxels to keep for one ray (for a cubic scene, this is at most 3*N_voxels^(1/3)-2) Outputs: hits_cnt: (N_rays) number of hits for each ray (followings are from near to far) hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit) hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit) """ @staticmethod @custom_fwd(cast_inputs=torch.float32) def forward(ctx, rays_o, rays_d, center, half_size, max_hits): hits_t = (torch.zeros( rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) - 1).contiguous() simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center, half_size) return None, hits_t, None