triro 1.3.0
A Python Ray-Mesh Intersector in OptiX
Loading...
Searching...
No Matches
ray.cpp
Go to the documentation of this file.
1
8#include "ray.h"
9#include "ATen/core/TensorBody.h"
10#include "ATen/ops/where.h"
11#include "CUDABuffer.h"
12#include "LaunchParams.h"
13#include "base.h"
14#include "c10/core/Layout.h"
15#include "c10/core/ScalarType.h"
16#include "c10/core/TensorOptions.h"
17#include "c10/util/ArrayRef.h"
18#include "optix8.h"
19#include "optix_host.h"
20#include "optix_types.h"
21#include "sbtdef.h"
22#include "type.h"
23#include <limits>
24
25namespace hmesh {
26
28 torch::Tensor faces) {
29 OptixAccelBuildOptions buildOptions = {};
30 OptixBuildInput buildInput = {};
31
32 // CUdeviceptr tempBuffer, outputBuffer;
33 size_t tempBufferSizeInBytes, outputBufferSizeInBytes;
34
35 buildOptions.buildFlags = OPTIX_BUILD_FLAG_NONE |
36 OPTIX_BUILD_FLAG_ALLOW_COMPACTION |
37 OPTIX_BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS;
38 buildOptions.operation = OPTIX_BUILD_OPERATION_BUILD;
39 buildOptions.motionOptions.numKeys = 0;
40
41 CUdeviceptr pVert = (CUdeviceptr)vertices.data_ptr();
42 CUdeviceptr pFace = (CUdeviceptr)faces.data_ptr();
43
44 buildInput.type = OPTIX_BUILD_INPUT_TYPE_TRIANGLES;
45 buildInput.triangleArray.vertexBuffers = &pVert;
46 buildInput.triangleArray.numVertices = vertices.size(0);
47 buildInput.triangleArray.vertexFormat = OPTIX_VERTEX_FORMAT_FLOAT3;
48 buildInput.triangleArray.vertexStrideInBytes = sizeof(vec3f);
49 buildInput.triangleArray.indexBuffer = pFace;
50 buildInput.triangleArray.numIndexTriplets = faces.size(0);
51 buildInput.triangleArray.indexFormat = OPTIX_INDICES_FORMAT_UNSIGNED_INT3;
52 buildInput.triangleArray.indexStrideInBytes = sizeof(vec3i);
53 buildInput.triangleArray.preTransform = 0;
54
55 buildInput.triangleArray.numSbtRecords = 1;
56 buildInput.triangleArray.sbtIndexOffsetBuffer = 0;
57 buildInput.triangleArray.sbtIndexOffsetSizeInBytes = 0;
58 buildInput.triangleArray.sbtIndexOffsetStrideInBytes = 0;
59
60 uint32_t triangleBuildFlags =
61 OPTIX_GEOMETRY_FLAG_REQUIRE_SINGLE_ANYHIT_CALL;
62 buildInput.triangleArray.flags = &triangleBuildFlags;
63
64 OptixAccelBufferSizes bufferSizes = {};
65 OPTIX_CHECK(optixAccelComputeMemoryUsage(optixContext, &buildOptions,
66 &buildInput, 1, &bufferSizes));
67
68 CUDABuffer tempBuffer;
69 CUDABuffer accelStructureBuffer;
70 accelStructureBuffer.alloc(bufferSizes.outputSizeInBytes);
71 tempBuffer.alloc(bufferSizes.tempSizeInBytes);
72
73 CUDABuffer compactedSizeBuffer;
74 compactedSizeBuffer.alloc(sizeof(uint64_t));
75 OptixAccelEmitDesc emitDesc;
76 emitDesc.type = OPTIX_PROPERTY_TYPE_COMPACTED_SIZE;
77 emitDesc.result = compactedSizeBuffer.d_pointer();
78
79 OPTIX_CHECK(optixAccelBuild(
80 optixContext, cuStream, &buildOptions, &buildInput, 1,
81 (CUdeviceptr)tempBuffer.d_ptr, tempBuffer.sizeInBytes,
82 (CUdeviceptr)accelStructureBuffer.d_ptr,
83 accelStructureBuffer.sizeInBytes, &asHandle, &emitDesc, 1));
84
86
87 uint64_t compactedSize;
88 compactedSizeBuffer.download(&compactedSize, 1);
89 asBuffer.resize(compactedSize);
90
91 OPTIX_CHECK(optixAccelCompact(optixContext, cuStream, asHandle,
92 asBuffer.d_pointer(), compactedSize,
93 &asHandle));
94
96
97 compactedSizeBuffer.free();
98 tempBuffer.free();
99 accelStructureBuffer.free();
100}
101
103
104template <typename... Ts> inline bool tensorInputCheck(Ts... ts) {
105 bool valid = true;
106 (
107 [&] {
108 if (!ts.is_cuda()) {
109 std::cerr << "error in file " << __FILE__ << " line "
110 << __LINE__
111 << ": input tensors must reside in cuda device.\n";
112 valid = false;
113 }
114 if (ts.layout() != torch::kStrided) {
115 std::cerr << "error in file " << __FILE__ << " line "
116 << __LINE__
117 << ": input tensor layout must be torch::kStrided.\n";
118 valid = false;
119 }
120 }(),
121 ...);
122 return valid;
123}
124
125inline std::vector<int64_t> removeLastDim(const c10::IntArrayRef dims) {
126 auto ref = dims.vec();
127 ref.pop_back();
128 return ref;
129}
130
/// @brief Multiply all entries of a shape together.
/// @param dims Dimension sizes.
/// @return The product of all entries, or 1 when `dims` is empty. Each entry
///         is converted to `size_t` before multiplying.
inline size_t prod(const std::vector<int64_t> &dims) {
    size_t p = 1;
    for (const int64_t s : dims)
        p *= static_cast<size_t>(s);
    return p;
}
137
138inline std::vector<int64_t> changeLastDim(const c10::IntArrayRef dims,
139 size_t value) {
140 std::vector<int64_t> dimsVec;
141 for (auto s : dims)
142 dimsVec.push_back(s);
143 *(dimsVec.end() - 1) = value;
144 return dimsVec;
145}
146
147template <typename T> inline T *data_ptr(const torch::Tensor &t) {
148 return (T *)t.data_ptr();
149}
150
151template <typename T>
152void fillArray(T *dst, c10::ArrayRef<T> src, T defaultValue) {
153 int i = 0;
154 const int src_size = src.size();
155 for (; i < MAX_SIZE_LENGTH - src_size; i++)
156 dst[i] = defaultValue;
157 for (; i < MAX_SIZE_LENGTH; i++)
158 dst[i] = src[i + src_size - MAX_SIZE_LENGTH];
159}
160
162 const torch::Tensor &origins,
163 const torch::Tensor &directions) {
164 if (!tensorInputCheck(origins, directions))
165 return {};
166 // output buffer
167 auto options =
168 torch::TensorOptions().dtype(torch::kBool).device(torch::kCUDA);
169 auto resultSize = removeLastDim(origins.sizes());
170 auto nray = prod(resultSize);
171 auto result = torch::empty(resultSize, options);
172 // fill launch params
173 LaunchParams lp = {};
174 lp.rays.origins = data_ptr<float>(origins);
175 lp.rays.directions = data_ptr<float>(directions);
176 lp.rays.nray = nray;
177 fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<int64_t>::max());
178 fillArray(lp.rays.originsStride, origins.strides(), (int64_t) 0);
179 fillArray(lp.rays.directionsStride, directions.strides(), (int64_t) 0);
180 lp.traversable = as.asHandle;
181 lp.results.hit = data_ptr<bool>(result);
182 CUDABuffer lpBuffer;
183 lpBuffer.alloc_and_upload(&lp, 1);
185 lpBuffer.d_pointer(), sizeof(lp),
187 lpBuffer.free();
188 return result;
189}
190
192 const torch::Tensor &origins,
193 const torch::Tensor &directions) {
194 if (!tensorInputCheck(origins, directions))
195 return {};
196 // output buffer
197 auto options =
198 torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
199 auto resultSize = removeLastDim(origins.sizes());
200 auto nray = prod(resultSize);
201 auto result = torch::empty(resultSize, options);
202 // fill launch params
203 LaunchParams lp = {};
204 lp.rays.origins = data_ptr<float>(origins);
205 lp.rays.directions = data_ptr<float>(directions);
206 lp.rays.nray = nray;
207 fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<int64_t>::max());
208 fillArray(lp.rays.originsStride, origins.strides(), (int64_t) 0);
209 fillArray(lp.rays.directionsStride, directions.strides(), (int64_t) 0);
210 lp.traversable = as.asHandle;
211 lp.results.triIdx = data_ptr<int>(result);
212 CUDABuffer lpBuffer;
213 lpBuffer.alloc_and_upload(&lp, 1);
215 lpBuffer.d_pointer(), sizeof(lp),
217 lpBuffer.free();
218 return result;
219}
220
231std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor,
232 torch::Tensor>
234 torch::Tensor directions) {
235 if (!tensorInputCheck(origins, directions))
236 return {};
237 // output buffers
238 // hitmask buffer
239 auto hitbufOptions =
240 torch::TensorOptions().dtype(torch::kBool).device(torch::kCUDA);
241 auto hitbufSize = removeLastDim(origins.sizes());
242 auto hitbuf = torch::empty(hitbufSize, hitbufOptions);
243 // front hit buffer
244 auto frontbuf = torch::empty(hitbufSize, hitbufOptions);
245 // triangle index buffer
246 auto tibufOptions =
247 torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
248 auto tibufSize = removeLastDim(origins.sizes());
249 auto tibuf = torch::empty(tibufSize, tibufOptions);
250 // intersect location buffer
251 auto locbufOptions =
252 torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
253 auto locbufSize = changeLastDim(origins.sizes(), 3);
254 auto locbuf = torch::empty(locbufSize, locbufOptions);
255 // uv buffer
256 auto uvbufOptions =
257 torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
258 auto uvbufSize = changeLastDim(origins.sizes(), 2);
259 auto uvbuf = torch::empty(uvbufSize, uvbufOptions);
260 auto nray = prod(hitbufSize);
261
262 // fill and upload launchParams
263 LaunchParams lp = {};
264 lp.rays.nray = nray;
265 lp.rays.origins = data_ptr<float>(origins);
266 lp.rays.directions = data_ptr<float>(directions);
267 fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<int64_t>::max());
268 fillArray(lp.rays.originsStride, origins.strides(), (int64_t) 0);
269 fillArray(lp.rays.directionsStride, directions.strides(), (int64_t) 0);
270
271 lp.results.hit = data_ptr<bool>(hitbuf);
272 lp.results.location = data_ptr<float3>(locbuf);
273 lp.results.triIdx = data_ptr<int>(tibuf);
274 lp.results.uv = data_ptr<float2>(uvbuf);
275 lp.results.front = data_ptr<bool>(frontbuf);
276
277 lp.traversable = as.asHandle;
278
279 CUDABuffer lpBuffer;
280 lpBuffer.alloc_and_upload(&lp, 1);
281
282 // Launch!
284 lpBuffer.d_pointer(), sizeof(lp),
285 &sbts[SBTType::INTERSECTS_CLOSEST], nray, 1, 1);
286
287 lpBuffer.free();
288 return {hitbuf, frontbuf, tibuf, locbuf, uvbuf};
289}
290
292 torch::Tensor origins, torch::Tensor directions) {
293 if (!tensorInputCheck(origins, directions))
294 return {};
295 // first pass - get the intersection count
296 auto hitCountBufSize = removeLastDim(origins.sizes());
297 auto hitCountBufOptions =
298 torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
299 auto hitCountBuf =
300 torch::zeros(hitCountBufSize, hitCountBufOptions).contiguous();
301 auto nray = prod(hitCountBufSize);
302
303 LaunchParams lp = {};
304 lp.rays.nray = nray;
305 lp.rays.origins = data_ptr<float>(origins);
306 lp.rays.directions = data_ptr<float>(directions);
307 fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<int64_t>::max());
308 fillArray(lp.rays.originsStride, origins.strides(), (int64_t) 0);
309 fillArray(lp.rays.directionsStride, directions.strides(), (int64_t) 0);
310 lp.results.hitCount = data_ptr<int>(hitCountBuf);
311 lp.traversable = as.asHandle;
312
313 CUDABuffer lpBuffer;
314 lpBuffer.alloc_and_upload(&lp, 1);
315
317 lpBuffer.d_pointer(), sizeof(lp),
318 &sbts[SBTType::INTERSECTS_COUNT], nray, 1, 1);
319
320 lpBuffer.free();
321 return hitCountBuf;
322}
323
324std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
326 torch::Tensor directions) {
327 if (!tensorInputCheck(origins, directions))
328 return {};
329 // first pass - get the intersection count
330 auto hitCountBuf = intersectsCount(as, origins, directions);
331
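332 // second pass: replace hit counts above MAX_ANYHIT_SIZE, then derive the total hit count and per-ray output offsets. NOTE(review): the offsets below are sliced from hitCountBuf, not from the cumsum result, so they are not prefix sums - confirm this is intended.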
333 hitCountBuf = hitCountBuf.flatten();
334 hitCountBuf = torch::where(hitCountBuf <= MAX_ANYHIT_SIZE, hitCountBuf,
336 auto globalIdxBuf = hitCountBuf.cumsum(0);
337 auto globalIdxBufOptions =
338 torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
339 auto nhits = globalIdxBuf[-1].item<int>();
340 globalIdxBuf = torch::cat({torch::zeros({1}, globalIdxBufOptions),
341 torch::slice(hitCountBuf, 0, 0, -1)});
342 // hit location
343 auto locbufOptions =
344 torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
345 auto locbuf = torch::empty({nhits, 3}, locbufOptions);
346 auto idxbufOptions =
347 torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA);
348 auto tibuf = torch::empty({nhits}, idxbufOptions);
349 auto ribuf = torch::empty({nhits}, idxbufOptions);
350
351 auto nray = prod(removeLastDim(origins.sizes()));
352
353 LaunchParams lp = {};
354 lp.traversable = as.asHandle;
355 lp.rays.nray = nray;
356 lp.rays.origins = data_ptr<float>(origins);
357 lp.rays.directions = data_ptr<float>(directions);
358 lp.rays.hitCounts = data_ptr<int>(hitCountBuf);
359 lp.rays.globalIdx = data_ptr<int>(globalIdxBuf);
360 fillArray(lp.rays.rayShape, origins.sizes(), std::numeric_limits<int64_t>::max());
361 fillArray(lp.rays.originsStride, origins.strides(), (int64_t) 0);
362 fillArray(lp.rays.directionsStride, directions.strides(), (int64_t) 0);
363 lp.results.hitCount = data_ptr<int>(hitCountBuf);
364 lp.results.location = data_ptr<float3>(locbuf);
365 lp.results.triIdx = data_ptr<int>(tibuf);
366 lp.results.rayIdx = data_ptr<int>(ribuf);
367
368 CUDABuffer lpBuffer;
369 lpBuffer.alloc_and_upload(&lp, 1);
370
372 cuStream, lpBuffer.d_pointer(), sizeof(lp),
373 &sbts[SBTType::INTERSECTS_LOCATION], nray, 1, 1);
374
375 lpBuffer.free();
376 return {locbuf, ribuf, tibuf};
377}
378
379} // namespace hmesh
Definition base.cpp:13
OptixShaderBindingTable sbts[SBTType::count]
Definition base.cpp:22
constexpr int MAX_ANYHIT_SIZE
Definition LaunchParams.h:8
T * data_ptr(const torch::Tensor &t)
Definition ray.cpp:147
constexpr int MAX_SIZE_LENGTH
Definition LaunchParams.h:9
std::tuple< torch::Tensor, torch::Tensor, torch::Tensor > intersectsLocation(OptixAccelStructureWrapperCPP as, torch::Tensor origins, torch::Tensor directions)
Definition ray.cpp:325
torch::Tensor intersectsFirst(OptixAccelStructureWrapperCPP as, const torch::Tensor &origins, const torch::Tensor &dirs)
Definition ray.cpp:191
OptixPipeline optixPipelines[SBTType::count]
Definition base.cpp:23
std::tuple< torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor > intersectsClosest(OptixAccelStructureWrapperCPP as, torch::Tensor origins, torch::Tensor directions)
Find the closest triangle hit by each ray and return the hit mask, front-hit flag, triangle index, hit location and uv.
Definition ray.cpp:233
size_t prod(const std::vector< int64_t > &dims)
Definition ray.cpp:131
torch::Tensor intersectsCount(OptixAccelStructureWrapperCPP as, torch::Tensor origins, torch::Tensor directions)
Definition ray.cpp:291
std::vector< int64_t > removeLastDim(const c10::IntArrayRef dims)
Definition ray.cpp:125
void fillArray(T *dst, c10::ArrayRef< T > src, T defaultValue)
Definition ray.cpp:152
OptixDeviceContext optixContext
Definition base.cpp:17
std::vector< int64_t > changeLastDim(const c10::IntArrayRef dims, size_t value)
Definition ray.cpp:138
torch::Tensor intersectsAny(OptixAccelStructureWrapperCPP as, const torch::Tensor &origins, const torch::Tensor &dirs)
Definition ray.cpp:161
CUstream cuStream
Definition base.cpp:16
vec< float, 3 > vec3f
Definition type.h:12
vec< int32_t, 3 > vec3i
Definition type.h:11
bool tensorInputCheck(Ts... ts)
Definition ray.cpp:104
#define OPTIX_CHECK(call)
Definition optix8.h:41
#define CUDA_SYNC_CHECK()
Definition optix8.h:51
@ INTERSECTS_LOCATION
Definition sbtdef.h:14
@ INTERSECTS_ANY
Definition sbtdef.h:10
@ INTERSECTS_FIRST
Definition sbtdef.h:11
@ INTERSECTS_COUNT
Definition sbtdef.h:13
@ INTERSECTS_CLOSEST
Definition sbtdef.h:12
CUdeviceptr d_pointer() const
Definition CUDABuffer.h:29
void free()
free allocated memory
Definition CUDABuffer.h:46
void resize(size_t size)
re-size buffer to given number of bytes
Definition CUDABuffer.h:32
void download(T *t, size_t count)
Definition CUDABuffer.h:69
void alloc(size_t size)
allocate to given number of bytes
Definition CUDABuffer.h:39
void alloc_and_upload(const std::vector< T > &vt)
Definition CUDABuffer.h:52
float3 * location
OptixTraversableHandle traversable
void buildAccelStructure(torch::Tensor vertices, torch::Tensor faces)
Definition ray.cpp:27
OptixTraversableHandle asHandle
Definition ray.h:12
int64_t rayShape[MAX_SIZE_LENGTH]
int64_t directionsStride[MAX_SIZE_LENGTH]
int64_t originsStride[MAX_SIZE_LENGTH]