3from typing
import Tuple, Optional
4from jaxtyping
import Float32, Int32, Bool
5import triro.backend.ops
as hops
29 self.
mesh_vertices = torch.from_numpy(mesh.vertices).float().contiguous().cuda()
31 self.
mesh_faces = torch.from_numpy(mesh.faces).int().contiguous().cuda()
32 elif 'vertices' in kwargs
and 'faces' in kwargs:
33 vertices = kwargs[
'vertices']
34 faces = kwargs[
'faces']
39 self.
mesh_faces = faces.int().contiguous().cuda()
41 raise ValueError(
"Either 'mesh' or 'vertices' and 'faces' must be provided.")
56 vertices: Float32[torch.Tensor,
"n 3"],
57 faces: Int32[torch.Tensor,
"f 3"]):
62 self.
mesh_faces = faces.int().contiguous().cuda()
79 origins: Float32[torch.Tensor,
"*b 3"],
80 directions: Float32[torch.Tensor,
"*b 3"],
81 ) -> Bool[torch.Tensor,
"*b"]:
82 return hops.intersects_any(self.
as_wrapper, origins, directions)
92 origins: Float32[torch.Tensor,
"*b 3"],
93 directions: Float32[torch.Tensor,
"*b 3"],
94 ) -> Int32[torch.Tensor,
"*b"]:
95 return hops.intersects_first(self.
as_wrapper, origins, directions)
119 origins: Float32[torch.Tensor,
"*b 3"],
120 directions: Float32[torch.Tensor,
"*b 3"],
121 stream_compaction: bool =
False,
124 Bool[torch.Tensor,
"*b"],
125 Bool[torch.Tensor,
"*b"],
126 Int32[torch.Tensor,
"*b"],
127 Float32[torch.Tensor,
"*b 3"],
128 Float32[torch.Tensor,
"*b 2"],
131 Bool[torch.Tensor,
"*b"],
132 Bool[torch.Tensor,
"h"],
133 Int32[torch.Tensor,
"h"],
134 Int32[torch.Tensor,
"h"],
135 Float32[torch.Tensor,
"h 3"],
136 Float32[torch.Tensor,
"h 2"],
139 hit, front, tri_idx, loc, uv = hops.intersects_closest(
142 if stream_compaction:
143 ray_idx = torch.arange(0, hit.shape.numel()).cuda().int()[hit.reshape(-1)]
144 return hit, front[hit], ray_idx, tri_idx[hit], loc[hit], uv[hit]
146 return hit, front, tri_idx, loc, uv
159 origins: Float32[torch.Tensor,
"*b 3"],
160 directions: Float32[torch.Tensor,
"*b 3"],
162 Float32[torch.Tensor,
"h 3"], Int32[torch.Tensor,
"h"], Int32[torch.Tensor,
"h"]
164 return hops.intersects_location(self.
as_wrapper, origins, directions)
174 origins: Float32[torch.Tensor,
"*b 3"],
175 directions: Float32[torch.Tensor,
"*b 3"],
176 ) -> Int32[torch.Tensor,
"*b 3"]:
177 return hops.intersects_count(self.
as_wrapper, origins, directions)
193 origins: Float32[torch.Tensor,
"*b 3"],
194 directions: Float32[torch.Tensor,
"*b 3"],
195 return_locations: bool =
False,
196 multiple_hits: bool =
True,
199 Int32[torch.Tensor,
"h"],
200 Int32[torch.Tensor,
"h"],
201 Float32[torch.Tensor,
"h 3"],
204 Int32[torch.Tensor,
"h"], Int32[torch.Tensor,
"h"]
208 loc, ray_idx, tri_idx = hops.intersects_location(
212 return tri_idx, ray_idx, loc
214 return tri_idx, ray_idx
216 hit, _, tri_idx, loc, _ = hops.intersects_closest(
219 ray_idx = torch.arange(0, hit.shape.numel()).cuda().int()[hit.reshape(-1)]
221 return tri_idx[hit], ray_idx, loc[hit]
223 return tri_idx[hit], ray_idx
233 points: Float32[torch.Tensor,
"*b 3"],
234 check_direction: Optional[Float32[torch.Tensor,
"3"]] =
None,
235 ) -> Bool[torch.Tensor,
"*b 3"]:
236 contains = torch.zeros(points.shape[:-1], dtype=torch.bool)
242 if not inside_aabb.any():
244 default_direction = torch.Tensor(
245 [0.4395064455, 0.617598629942, 0.652231566745]
248 if check_direction
is None:
249 ray_directions = torch.tile(default_direction, [*contains.shape, 1])
251 ray_directions = torch.tile(check_direction, [*contains.shape, 1])
253 hit_count = torch.stack(
255 hops.intersects_count(self.
as_wrapper, points, ray_directions),
256 hops.intersects_count(self.
as_wrapper, points, -ray_directions),
261 hit_count_mod_2 = torch.remainder(hit_count, 2)
262 agree = torch.equal(hit_count_mod_2[0], hit_count_mod_2[1])
264 contain = inside_aabb & agree & hit_count_mod_2[0] == 1
266 broken_mask = ~agree & (hit_count == 0).any(dim=-1)
267 if not broken_mask.any():
270 if check_direction
is None:
271 new_direction = (torch.rand(3) - 0.5).cuda()
272 contains[broken_mask] = self.
contains_points(self, points, new_direction)
279 self.
_inner = hops.get_module().OptixAccelStructureWrapperCPP()
282 self.
_inner.freeAccelStructure()
286 vertices: Float32[torch.Tensor,
"nvert 3"],
287 faces: Int32[torch.Tensor,
"nface 3"],
289 self.
_inner.buildAccelStructure(vertices, faces)
build_accel_structure(self, Float32[torch.Tensor, "nvert 3"] vertices, Int32[torch.Tensor, "nface 3"] faces)
A class for performing ray-mesh intersection tests using an OptiX acceleration structure.
Int32[torch.Tensor, "*b"] intersects_first(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Find the index of the first intersection for each ray.
Bool[torch.Tensor, "*b 3"] contains_points(self, Float32[torch.Tensor, "*b 3"] points, Optional[Float32[torch.Tensor, "3"]] check_direction=None)
Check if the points are contained within the mesh.
__init__(self, **kwargs)
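Initialize a RayMeshIntersector instance from either a `mesh` keyword argument or both `vertices` and `faces` keyword arguments; a ValueError is raised if neither is provided.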
(Tuple[
    Bool[torch.Tensor, "*b"],       # hit
    Bool[torch.Tensor, "*b"],       # front
    Int32[torch.Tensor, "*b"],      # triangle index
    Float32[torch.Tensor, "*b 3"],  # intersect location
    Float32[torch.Tensor, "*b 2"],  # uv
] | Tuple[
    Bool[torch.Tensor, "*b"],       # hit
    Bool[torch.Tensor, "h"],        # front
    Int32[torch.Tensor, "h"],       # ray index
    Int32[torch.Tensor, "h"],       # triangle index
    Float32[torch.Tensor, "h 3"],   # intersect location
    Float32[torch.Tensor, "h 2"],   # uv
]) intersects_closest(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions, bool stream_compaction=False)
Find the closest intersection for each ray.
(Tuple[
    Int32[torch.Tensor, "h"],       # hit triangle indices
    Int32[torch.Tensor, "h"],       # ray indices
    Float32[torch.Tensor, "h 3"],   # hit locations
] | Tuple[
    Int32[torch.Tensor, "h"],       # hit triangle indices
    Int32[torch.Tensor, "h"],       # ray indices
]) intersects_id(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions, bool return_locations=False, bool multiple_hits=True)
Find the triangle index and ray index of each intersection, optionally also returning the hit locations (`return_locations=True`).
Tuple[ Float32[torch.Tensor, "h 3"], Int32[torch.Tensor, "h"], Int32[torch.Tensor, "h"]] intersects_location(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Find the location of every intersection, together with the ray index and triangle index of each hit.
Int32[torch.Tensor, "*b 3"] intersects_count(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Count the number of intersections for each ray.
update_raw(self, Float32[torch.Tensor, "n 3"] vertices, Int32[torch.Tensor, "f 3"] faces)
Update the raw mesh data.
Bool[torch.Tensor, "*b"] intersects_any(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Check if any intersections occur for each ray.