Files
Magic123/taichi_modules/intersection.py
Guocheng Qian 13e18567fa first commit
2023-08-02 19:51:43 -07:00

69 lines
2.2 KiB
Python

import taichi as ti
import torch
from taichi.math import vec3
from torch.cuda.amp import custom_fwd
from .utils import NEAR_DISTANCE
@ti.kernel
def simple_ray_aabb_intersec_taichi_forward(
hits_t: ti.types.ndarray(ndim=2),
rays_o: ti.types.ndarray(ndim=2),
rays_d: ti.types.ndarray(ndim=2),
centers: ti.types.ndarray(ndim=2),
half_sizes: ti.types.ndarray(ndim=2)):
for r in ti.ndrange(hits_t.shape[0]):
ray_o = vec3([rays_o[r, 0], rays_o[r, 1], rays_o[r, 2]])
ray_d = vec3([rays_d[r, 0], rays_d[r, 1], rays_d[r, 2]])
inv_d = 1.0 / ray_d
center = vec3([centers[0, 0], centers[0, 1], centers[0, 2]])
half_size = vec3(
[half_sizes[0, 0], half_sizes[0, 1], half_sizes[0, 1]])
t_min = (center - half_size - ray_o) * inv_d
t_max = (center + half_size - ray_o) * inv_d
_t1 = ti.min(t_min, t_max)
_t2 = ti.max(t_min, t_max)
t1 = _t1.max()
t2 = _t2.min()
if t2 > 0.0:
hits_t[r, 0, 0] = ti.max(t1, NEAR_DISTANCE)
hits_t[r, 0, 1] = t2
class RayAABBIntersector(torch.autograd.Function):
"""
Computes the intersections of rays and axis-aligned voxels.
Inputs:
rays_o: (N_rays, 3) ray origins
rays_d: (N_rays, 3) ray directions
centers: (N_voxels, 3) voxel centers
half_sizes: (N_voxels, 3) voxel half sizes
max_hits: maximum number of intersected voxels to keep for one ray
(for a cubic scene, this is at most 3*N_voxels^(1/3)-2)
Outputs:
hits_cnt: (N_rays) number of hits for each ray
(followings are from near to far)
hits_t: (N_rays, max_hits, 2) hit t's (-1 if no hit)
hits_voxel_idx: (N_rays, max_hits) hit voxel indices (-1 if no hit)
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, rays_o, rays_d, center, half_size, max_hits):
hits_t = (torch.zeros(
rays_o.size(0), 1, 2, device=rays_o.device, dtype=torch.float32) -
1).contiguous()
simple_ray_aabb_intersec_taichi_forward(hits_t, rays_o, rays_d, center,
half_size)
return None, hits_t, None