triro 1.3.0
A Python Ray-Mesh Intersector in OptiX
Loading...
Searching...
No Matches
ray_optix.py
Go to the documentation of this file.
1import torch
2import trimesh
3from typing import Tuple, Optional
4from jaxtyping import Float32, Int32, Bool
5import triro.backend.ops as hops
6
7
8
19
20
24 def __init__(self, **kwargs):
# Initialize the RayMeshIntersector from either a mesh object or raw tensors.
# Keyword arguments (if both forms are supplied, `mesh` wins because it is checked first):
#   mesh: object exposing `.vertices` and `.faces` as NumPy arrays (they go through
#         torch.from_numpy); presumably a trimesh.Trimesh — TODO confirm with callers.
#   vertices, faces: tensors; converted to float32 / int32, made contiguous, moved to CUDA.
# Raises ValueError when neither form is provided.
25 if 'mesh' in kwargs:
26 mesh = kwargs['mesh']
27 # mesh vertices
28 # [n, 3] float32 on the device
29 self.mesh_vertices = torch.from_numpy(mesh.vertices).float().contiguous().cuda()
30 # [f, 3] int32 on the device (one row of vertex indices per triangle)
31 self.mesh_faces = torch.from_numpy(mesh.faces).int().contiguous().cuda()
32 elif 'vertices' in kwargs and 'faces' in kwargs:
33 vertices = kwargs['vertices']
34 faces = kwargs['faces']
35 # mesh vertices
36 # [n, 3] float32 on the device
37 self.mesh_vertices = vertices.float().contiguous().cuda()
38 # [f, 3] int32 on the device (one row of vertex indices per triangle)
39 self.mesh_faces = faces.int().contiguous().cuda()
40 else:
41 raise ValueError("Either 'mesh' or 'vertices' and 'faces' must be provided.")
42 # ([3], [3]): axis-aligned bounding box as (per-axis min, per-axis max) of the vertices
43 self.mesh_aabb = (
44 torch.min(self.mesh_vertices, dim=0)[0],
45 torch.max(self.mesh_vertices, dim=0)[0],
46 )
47 # build acceleration structure
# NOTE(review): original line 48 is not shown in this listing; it is expected to create
# `self.as_wrapper`, which the next line uses — confirm against the source file.
49 self.as_wrapper.build_accel_structure(self.mesh_vertices, self.mesh_faces)
50
51
def update_raw(
    self,
    vertices: Float32[torch.Tensor, "n 3"],
    faces: Int32[torch.Tensor, "f 3"],
):
    """Update the raw mesh data and rebuild the acceleration structure.

    Args:
        vertices: ``[n, 3]`` vertex positions; stored in ``self.mesh_vertices``
            as a contiguous float32 CUDA tensor.
        faces: ``[f, 3]`` per-triangle vertex indices; stored in
            ``self.mesh_faces`` as a contiguous int32 CUDA tensor.

    Side effects:
        Recomputes ``self.mesh_aabb`` as ``(per-axis min, per-axis max)`` of the
        vertices and rebuilds the structure held by ``self.as_wrapper``.
    """
    # Normalise dtype/layout and move to the GPU before anything reads them.
    verts = vertices.float().contiguous().cuda()
    self.mesh_vertices = verts
    tris = faces.int().contiguous().cuda()
    self.mesh_faces = tris
    # Axis-aligned bounding box: (per-axis minimum [3], per-axis maximum [3]).
    self.mesh_aabb = (verts.min(dim=0).values, verts.max(dim=0).values)
    self.as_wrapper.build_accel_structure(verts, tris)
70
71
def intersects_any(
    self,
    origins: Float32[torch.Tensor, "*b 3"],
    directions: Float32[torch.Tensor, "*b 3"],
) -> Bool[torch.Tensor, "*b"]:
    """Check if any intersections occur for each ray.

    Thin wrapper over ``hops.intersects_any`` using this mesh's acceleration
    structure.

    Args:
        origins: Ray origins, ``[*b, 3]`` float32.
        directions: Ray directions, ``[*b, 3]`` float32.

    Returns:
        ``[*b]`` boolean tensor, one flag per ray.
    """
    accel = self.as_wrapper
    any_hit = hops.intersects_any(accel, origins, directions)
    return any_hit
83
84
def intersects_first(
    self,
    origins: Float32[torch.Tensor, "*b 3"],
    directions: Float32[torch.Tensor, "*b 3"],
) -> Int32[torch.Tensor, "*b"]:
    """Find the index of the first intersection for each ray.

    Thin wrapper over ``hops.intersects_first`` using this mesh's acceleration
    structure. NOTE(review): the value reported for rays that miss is decided
    by the backend, which is not visible here — confirm before relying on it.

    Args:
        origins: Ray origins, ``[*b, 3]`` float32.
        directions: Ray directions, ``[*b, 3]`` float32.

    Returns:
        ``[*b]`` int32 tensor, one entry per ray.
    """
    accel = self.as_wrapper
    first_idx = hops.intersects_first(accel, origins, directions)
    return first_idx
96
97
def intersects_closest(
    self,
    origins: Float32[torch.Tensor, "*b 3"],
    directions: Float32[torch.Tensor, "*b 3"],
    stream_compaction: bool = False,
) -> (
    Tuple[
        Bool[torch.Tensor, "*b"],  # hit
        Bool[torch.Tensor, "*b"],  # front
        Int32[torch.Tensor, "*b"],  # triangle index
        Float32[torch.Tensor, "*b 3"],  # intersect location
        Float32[torch.Tensor, "*b 2"],  # uv
    ]
    | Tuple[
        Bool[torch.Tensor, "*b"],  # hit
        Bool[torch.Tensor, "h"],  # front
        Int32[torch.Tensor, "h"],  # ray index
        Int32[torch.Tensor, "h"],  # triangle index
        Float32[torch.Tensor, "h 3"],  # intersect location
        Float32[torch.Tensor, "h 2"],  # uv
    ]
):
    """Find the closest intersection for each ray.

    Args:
        origins: Ray origins, ``[*b, 3]`` float32.
        directions: Ray directions, ``[*b, 3]`` float32.
        stream_compaction: If True, keep only the rays that hit something and
            also return the flattened index of each of those rays.

    Returns:
        Without compaction: ``(hit, front, tri_idx, loc, uv)``, each with the
        batch shape ``*b`` (plus a trailing 3 / 2 for ``loc`` / ``uv``).
        With compaction: ``(hit, front, ray_idx, tri_idx, loc, uv)``. ``hit``
        keeps the batch shape; the other tensors have one row per hitting ray
        (``h`` rows). ``ray_idx`` indexes into the flattened ray batch.
    """
    hit, front, tri_idx, loc, uv = hops.intersects_closest(
        self.as_wrapper, origins, directions
    )
    if not stream_compaction:
        return hit, front, tri_idx, loc, uv

    flat_hit = hit.reshape(-1)
    # Build the index vector directly on hit's device (as int32) instead of on
    # the CPU followed by .cuda(), which targets the *current* CUDA device and
    # could disagree with the device of the mask.
    ray_idx = torch.arange(flat_hit.numel(), device=hit.device, dtype=torch.int32)[
        flat_hit
    ]
    return hit, front[hit], ray_idx, tri_idx[hit], loc[hit], uv[hit]
147
148
def intersects_location(
    self,
    origins: Float32[torch.Tensor, "*b 3"],
    directions: Float32[torch.Tensor, "*b 3"],
) -> Tuple[
    Float32[torch.Tensor, "h 3"], Int32[torch.Tensor, "h"], Int32[torch.Tensor, "h"]
]:
    """Find the intersection locations for each ray.

    Thin wrapper over ``hops.intersects_location`` using this mesh's
    acceleration structure.

    Args:
        origins: Ray origins, ``[*b, 3]`` float32.
        directions: Ray directions, ``[*b, 3]`` float32.

    Returns:
        ``(locations [h, 3], ray_idx [h], tri_idx [h])``, with one row per
        reported hit.
    """
    accel = self.as_wrapper
    hits = hops.intersects_location(accel, origins, directions)
    return hits
165
166
def intersects_count(
    self,
    origins: Float32[torch.Tensor, "*b 3"],
    directions: Float32[torch.Tensor, "*b 3"],
) -> Int32[torch.Tensor, "*b"]:
    """Count the number of intersections for each ray.

    Thin wrapper over ``hops.intersects_count`` using this mesh's acceleration
    structure.

    Args:
        origins: Ray origins, ``[*b, 3]`` float32.
        directions: Ray directions, ``[*b, 3]`` float32.

    Returns:
        ``[*b]`` int32 tensor holding one hit count per ray.
    """
    return hops.intersects_count(self.as_wrapper, origins, directions)
178
179
def intersects_id(
    self,
    origins: Float32[torch.Tensor, "*b 3"],
    directions: Float32[torch.Tensor, "*b 3"],
    return_locations: bool = False,
    multiple_hits: bool = True,
) -> (
    Tuple[
        Int32[torch.Tensor, "h"],  # hit triangle indices
        Int32[torch.Tensor, "h"],  # ray indices
        Float32[torch.Tensor, "h 3"],  # hit location
    ]
    | Tuple[
        Int32[torch.Tensor, "h"], Int32[torch.Tensor, "h"]
    ]  # hit triangle indices and ray indices
):
    """Find the intersection indices for each ray.

    Args:
        origins: Ray origins, ``[*b, 3]`` float32.
        directions: Ray directions, ``[*b, 3]`` float32.
        return_locations: Also return the hit locations.
        multiple_hits: If True, report all hits returned by
            ``hops.intersects_location``; otherwise report only the closest hit
            of each ray that hits something.

    Returns:
        ``(tri_idx, ray_idx)``, or ``(tri_idx, ray_idx, locations)`` when
        ``return_locations`` is set. There is one entry per reported hit. In
        closest-hit mode ``ray_idx`` indexes the flattened ray batch.
    """
    if multiple_hits:
        loc, ray_idx, tri_idx = hops.intersects_location(
            self.as_wrapper, origins, directions
        )
    else:
        hit, _, tri_idx, loc, _ = hops.intersects_closest(
            self.as_wrapper, origins, directions
        )
        flat_hit = hit.reshape(-1)
        # Built on hit's device (not CPU + .cuda()), so the mask and the index
        # vector are always on the same device.
        ray_idx = torch.arange(
            flat_hit.numel(), device=hit.device, dtype=torch.int32
        )[flat_hit]
        tri_idx = tri_idx[hit]
        if return_locations:
            loc = loc[hit]

    if return_locations:
        return tri_idx, ray_idx, loc
    return tri_idx, ray_idx
224
225
def contains_points(
    self,
    points: Float32[torch.Tensor, "*b 3"],
    check_direction: Optional[Float32[torch.Tensor, "3"]] = None,
) -> Bool[torch.Tensor, "*b"]:
    """Check if the points are contained within the mesh.

    Uses a bidirectional ray-parity test. For every point strictly inside the
    mesh's axis-aligned bounding box, a ray is cast along the check direction
    and along its opposite, and the hits in each direction are counted. A point
    is reported inside when both counts are odd.

    A direction with zero hits is taken as evidence of free space, so a parity
    disagreement there means "outside". Points whose parities disagree while
    both directions hit something are re-tested once with a random direction.
    This retry happens only when ``check_direction`` was not given; the
    recursive call passes a direction, which limits recursion to one level.

    Args:
        points: Query points, shape ``[*b, 3]``; expected on the same device as
            the mesh.
        check_direction: Optional ray direction, shape ``[3]``. Defaults to a
            fixed non-axis-aligned direction.

    Returns:
        Boolean tensor of shape ``[*b]`` on ``points.device``; True where the
        point is judged to be inside the mesh.
    """
    batch_shape = points.shape[:-1]
    flat_points = points.reshape(-1, 3)
    contains = torch.zeros(
        flat_points.shape[0], dtype=torch.bool, device=points.device
    )

    # Per-point test: every coordinate strictly inside the mesh AABB. Positive
    # comparisons are used so that NaN coordinates count as outside.
    aabb_min, aabb_max = self.mesh_aabb
    inside_aabb = ((flat_points > aabb_min) & (flat_points < aabb_max)).all(dim=-1)
    if not inside_aabb.any():
        return contains.reshape(batch_shape)

    candidates = flat_points[inside_aabb]
    if check_direction is None:
        direction = torch.tensor(
            [0.4395064455, 0.617598629942, 0.652231566745],
            dtype=points.dtype,
            device=points.device,
        )
    else:
        direction = check_direction.to(
            device=points.device, dtype=points.dtype
        ).reshape(3)
    ray_directions = torch.tile(direction, (candidates.shape[0], 1))

    # [2, m]: row 0 counts hits along +direction, row 1 along -direction.
    hit_count = torch.stack(
        [
            hops.intersects_count(self.as_wrapper, candidates, ray_directions),
            hops.intersects_count(self.as_wrapper, candidates, -ray_directions),
        ],
        dim=0,
    )
    odd = torch.remainder(hit_count, 2) == 1
    agree = odd[0] == odd[1]
    contains[inside_aabb] = agree & odd[0]

    # Only disagreements where both directions hit something are ambiguous.
    free_space = (hit_count == 0).any(dim=0)
    broken = ~agree & ~free_space
    if check_direction is None and broken.any():
        new_direction = torch.rand(3, dtype=points.dtype, device=points.device) - 0.5
        new_direction = new_direction / new_direction.norm()
        # Lift the per-candidate mask back to a mask over all flattened points.
        retry = inside_aabb.clone()
        retry[inside_aabb] = broken
        contains[retry] = self.contains_points(flat_points[retry], new_direction)

    return contains.reshape(batch_shape)
275
276
def __init__(self):
    """Create the native acceleration-structure wrapper.

    The native object comes from the module returned by ``hops.get_module()``.
    """
    backend = hops.get_module()
    self._inner = backend.OptixAccelStructureWrapperCPP()
280
def __del__(self):
    """Free the native acceleration structure when this wrapper is finalized.

    ``__init__`` can fail before ``_inner`` is assigned (for example if
    ``hops.get_module()`` raises). Python still runs the finalizer on such a
    half-built object, so a missing ``_inner`` is treated as "nothing to
    free" rather than raising AttributeError from inside ``__del__``.
    """
    inner = getattr(self, "_inner", None)
    if inner is None:
        return
    inner.freeAccelStructure()
    # Drop the reference so an explicit second __del__() call cannot free twice.
    self._inner = None
283
def build_accel_structure(
    self,
    vertices: Float32[torch.Tensor, "nvert 3"],
    faces: Int32[torch.Tensor, "nface 3"],
):
    """Build the acceleration structure for the given triangle mesh.

    Forwards to the native ``buildAccelStructure`` method.

    Args:
        vertices: ``[nvert, 3]`` float32 vertex positions.
        faces: ``[nface, 3]`` int32 per-triangle vertex indices.
    """
    native = self._inner
    native.buildAccelStructure(vertices, faces)
build_accel_structure(self, Float32[torch.Tensor, "nvert 3"] vertices, Int32[torch.Tensor, "nface 3"] faces)
Definition ray_optix.py:288
A class for performing ray-mesh intersection tests using OptiX acceleration structure.
Definition ray_optix.py:18
Int32[torch.Tensor, "*b"] intersects_first(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Find the index of the first intersection for each ray.
Definition ray_optix.py:94
Bool[torch.Tensor, "*b 3"] contains_points(self, Float32[torch.Tensor, "*b 3"] points, Optional[Float32[torch.Tensor, "3"]] check_direction=None)
Check if the points are contained within the mesh.
Definition ray_optix.py:235
__init__(self, **kwargs)
Initialize the RayMeshIntersector class.
Definition ray_optix.py:24
Tuple[Bool[torch.Tensor, "*b"] hit, Bool[torch.Tensor, "*b"] front, Int32[torch.Tensor, "*b"] triangle_index, Float32[torch.Tensor, "*b 3"] location, Float32[torch.Tensor, "*b 2"] uv] | Tuple[Bool[torch.Tensor, "*b"] hit, Bool[torch.Tensor, "h"] front, Int32[torch.Tensor, "h"] ray_index, Int32[torch.Tensor, "h"] triangle_index, Float32[torch.Tensor, "h 3"] location, Float32[torch.Tensor, "h 2"] uv] intersects_closest(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions, bool stream_compaction=False)
Find the closest intersection for each ray.
Definition ray_optix.py:138
Tuple[Int32[torch.Tensor, "h"] triangle_indices, Int32[torch.Tensor, "h"] ray_indices, Float32[torch.Tensor, "h 3"] locations] | Tuple[Int32[torch.Tensor, "h"] triangle_indices, Int32[torch.Tensor, "h"] ray_indices] intersects_id(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions, bool return_locations=False, bool multiple_hits=True)
Find the intersection indices for each ray.
Definition ray_optix.py:206
Tuple[ Float32[torch.Tensor, "h 3"], Int32[torch.Tensor, "h"], Int32[torch.Tensor, "h"]] intersects_location(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Find the intersection location for each ray.
Definition ray_optix.py:163
Int32[torch.Tensor, "*b 3"] intersects_count(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Count the number of intersections for each ray.
Definition ray_optix.py:176
update_raw(self, Float32[torch.Tensor, "n 3"] vertices, Int32[torch.Tensor, "f 3"] faces)
Update the raw mesh data.
Definition ray_optix.py:57
Bool[torch.Tensor, "*b"] intersects_any(self, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] directions)
Check if any intersections occur for each ray.
Definition ray_optix.py:81