3from typing
import Tuple
4from jaxtyping
import Float32, Bool, Int32
6import torch.utils.cpp_extension
14 Get the triro module by compiling it if necessary and importing it.
20 if triro_module
is not None:
24 source_files = [
'base.cpp',
'binding.cpp',
'ray.cpp']
26 optix_install_dir = os.environ[
'OptiX_INSTALL_DIR']
28 cflags = [f
'-I{optix_install_dir}/include']
32 source_paths = [os.path.join(os.path.dirname(__file__), fn)
for fn
in source_files]
35 cflags += [
'/DNOMINMAX']
36 ldflags = [
'/DLL',
'cuda.lib',
'cudart.lib',
'Advapi32.lib']
37 torch.utils.cpp_extension.load(
41 extra_ldflags=ldflags,
45 triro_module = importlib.import_module(
"triro")
58 Create an OptiX context.
65 Create an OptiX module.
72 Create OptiX pipelines.
79 Build the SBTs (Shader Binding Tables).
85 accel_structure: OptixAccelStructureWrapper,
86 origins: Float32[torch.Tensor,
"*b 3"],
87 dirs: Float32[torch.Tensor,
"*b 3"],
88) -> Bool[torch.Tensor,
"*b"]:
90 Check if any ray intersects with the acceleration structure.
93 accel_structure: The acceleration structure.
94 origins: The origins of the rays.
95 dirs: The directions of the rays.
98 A boolean tensor indicating if each ray intersects with the acceleration structure.
100 return get_module().intersectsAny(accel_structure._inner, origins, dirs)
104 accel_structure: OptixAccelStructureWrapper,
105 origins: Float32[torch.Tensor,
"*b 3"],
106 dirs: Float32[torch.Tensor,
"*b 3"],
107) -> Int32[torch.Tensor,
"*b"]:
109 Find the index of the first intersection for each ray.
112 accel_structure: The acceleration structure.
113 origins: The origins of the rays.
114 dirs: The directions of the rays.
117 An integer tensor indicating the index of the first intersection for each ray.
119 return get_module().intersectsFirst(accel_structure._inner, origins, dirs)
123 accel_structure: OptixAccelStructureWrapper,
124 origins: Float32[torch.Tensor,
"*b 3"],
125 dirs: Float32[torch.Tensor,
"*b 3"],
127 Bool[torch.Tensor,
"*b"],
128 Bool[torch.Tensor,
"*b"],
129 Int32[torch.Tensor,
"*b"],
130 Float32[torch.Tensor,
"*b 3"],
131 Float32[torch.Tensor,
"*b 2"],
134 Find the closest intersection for each ray.
137 accel_structure: The acceleration structure.
138 origins: The origins of the rays.
139 dirs: The directions of the rays.
142 A tuple containing the following tensors:
143 - A boolean tensor indicating if each ray hits an object.
144 - A boolean tensor indicating if each ray hits the front face of an object.
145 - An integer tensor indicating the index of the triangle that each ray intersects with.
146 - A float tensor indicating the location of the intersection for each ray.
147 - A float tensor indicating the UV coordinates of the intersection for each ray.
149 return get_module().intersectsClosest(accel_structure._inner, origins, dirs)
153 accel_structure: OptixAccelStructureWrapper,
154 origins: Float32[torch.Tensor,
"*b 3"],
155 dirs: Float32[torch.Tensor,
"*b 3"],
156) -> Int32[torch.Tensor,
"*b"]:
158 Count the number of intersections for each ray.
161 accel_structure: The acceleration structure.
162 origins: The origins of the rays.
163 dirs: The directions of the rays.
166 An integer tensor indicating the number of intersections for each ray.
168 return get_module().intersectsCount(accel_structure._inner, origins, dirs)
172 accel_structure: OptixAccelStructureWrapper,
173 origins: Float32[torch.Tensor,
"*b 3"],
174 dirs: Float32[torch.Tensor,
"*b 3"],
176 Float32[torch.Tensor,
"h 3"], Int32[torch.Tensor,
"h"], Int32[torch.Tensor,
"h"]
179 Find the location of intersections for each ray.
182 accel_structure: The acceleration structure.
183 origins: The origins of the rays.
184 dirs: The directions of the rays.
187 A tuple containing the following tensors:
188 - A float tensor indicating the location of the intersection for each ray.
189 - The index of the ray that had the intersection.
190 - An integer tensor indicating the index of the instance that each ray intersects with.
192 return get_module().intersectsLocation(accel_structure._inner, origins, dirs)
Int32[torch.Tensor, "*b"] intersects_count(OptixAccelStructureWrapper accel_structure, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] dirs)
Tuple[ Bool[torch.Tensor, "*b"], # hit Bool[torch.Tensor, "*b"], # front Int32[torch.Tensor, "*b"], # triangle index Float32[torch.Tensor, "*b 3"], # intersect location Float32[torch.Tensor, "*b 2"], # uv] intersects_closest(OptixAccelStructureWrapper accel_structure, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] dirs)
Int32[torch.Tensor, "*b"] intersects_first(OptixAccelStructureWrapper accel_structure, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] dirs)
Tuple[ Float32[torch.Tensor, "h 3"], Int32[torch.Tensor, "h"], Int32[torch.Tensor, "h"]] intersects_location(OptixAccelStructureWrapper accel_structure, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] dirs)
Bool[torch.Tensor, "*b"] intersects_any(OptixAccelStructureWrapper accel_structure, Float32[torch.Tensor, "*b 3"] origins, Float32[torch.Tensor, "*b 3"] dirs)