triro 1.3.0
A Python Ray-Mesh Intersector in OptiX
Loading...
Searching...
No Matches
shaders.cu
Go to the documentation of this file.
1
2#include "LaunchParams.h"
3#include "optix_types.h"
4#include <cuda_device_runtime_api.h>
5#include <optix_device.h>
6#include <tuple>
7
8namespace hmesh {
9
// Launch parameters shared by every program in this file: the traversable
// handle, the ray input description and the result buffers.
extern "C" __constant__ LaunchParams launchParams;
11
// Splits a 64-bit pointer into two 32-bit words so that it can be handed to
// optixTrace as two payload values.
//
// Returns {low 32 bits, high 32 bits} of the pointer value.
__forceinline__ __device__ std::tuple<unsigned int, unsigned int>
setPayloadPointer(void *p) {
  const unsigned long long bits = reinterpret_cast<unsigned long long>(p);
  const unsigned int lowWord = static_cast<unsigned int>(bits & 0xFFFFFFFFllu);
  const unsigned int highWord = static_cast<unsigned int>(bits >> 32);
  return {lowWord, highWord};
}
18
// Rebuilds the pointer that was packed by setPayloadPointer: payload 0 holds
// the low 32 bits, payload 1 the high 32 bits.
//
// Device-only: optixGetPayload_0/1 are OptiX device intrinsics, so this helper
// cannot be called from host code and must not be marked __host__.
//
// Returns the reassembled pointer, cast to T*.
template <typename T>
__forceinline__ __device__ T *getPayloadPointer() {
  const unsigned long long lowWord = optixGetPayload_0();
  const unsigned long long highWord = optixGetPayload_1();
  // The low word never exceeds 32 bits, so OR-ing is equivalent to adding.
  return reinterpret_cast<T *>((highWord << 32) | lowWord);
}
26
// Decomposes a flat element index into one index per dimension, with the last
// dimension varying fastest.
//
// indices: output, receives MAX_SIZE_LENGTH per-dimension indices
// shape:   extent of each of the MAX_SIZE_LENGTH dimensions
// idx:     flat index to decompose
__forceinline__ __device__ void getIndices(int64_t indices[MAX_SIZE_LENGTH], int64_t shape[MAX_SIZE_LENGTH], int idx) {
  int remaining = idx;
#pragma unroll
  for (int dim = MAX_SIZE_LENGTH - 1; dim >= 0; --dim) {
    const int64_t extent = shape[dim];
    // Peel off the index of this dimension, then carry the rest upwards.
    indices[dim] = remaining % extent;
    remaining /= extent;
  }
}
34
// Loads the origin and direction of ray `idx` from the strided input buffers
// described by launchParams.rays.
//
// Returns {origin, direction}.
__forceinline__ __device__ std::tuple<float3, float3> getRay(int idx) {
  // Ray `idx` starts at float number 3 * idx (float index in [0, 3N)).
  // Resolve that flat float index into an index per dimension.
  int64_t dimIndex[MAX_SIZE_LENGTH];
  getIndices(dimIndex, launchParams.rays.rayShape, idx * 3);

  // Element offsets of the first component (x) in each flat array.
  int64_t originBase = 0;
  int64_t directionBase = 0;
#pragma unroll
  for (int dim = 0; dim < MAX_SIZE_LENGTH; ++dim) {
    originBase += dimIndex[dim] * launchParams.rays.originsStride[dim];
    directionBase += dimIndex[dim] * launchParams.rays.directionsStride[dim];
  }

  // The y and z components sit one and two last-dimension strides further on.
  const int64_t originStep =
      launchParams.rays.originsStride[MAX_SIZE_LENGTH - 1];
  const int64_t directionStep =
      launchParams.rays.directionsStride[MAX_SIZE_LENGTH - 1];

  float3 origin;
  origin.x = launchParams.rays.origins[originBase];
  origin.y = launchParams.rays.origins[originBase + originStep];
  origin.z = launchParams.rays.origins[originBase + 2 * originStep];

  float3 direction;
  direction.x = launchParams.rays.directions[directionBase];
  direction.y = launchParams.rays.directions[directionBase + directionStep];
  direction.z = launchParams.rays.directions[directionBase + 2 * directionStep];

  return {origin, direction};
}
64
65// intersects_any
66
// Miss program of the intersects_any query: stores `false` in the bool that
// the payload pointer refers to.
extern "C" __global__ void __miss__intersectsAny() {
  *getPayloadPointer<bool>() = false;
}
71
// Any-hit program of the intersects_any query: stores `true` in the bool that
// the payload pointer refers to.
//
// The query only asks whether the ray hits anything, so the answer is final as
// soon as one intersection is reported. optixTerminateRay() accepts the current
// hit and stops traversal instead of continuing to search the acceleration
// structure for closer hits that cannot change the result.
extern "C" __global__ void __anyhit__intersectsAny() {
  bool *result_pt = getPayloadPointer<bool>();
  *result_pt = true;
  optixTerminateRay();
}
76
// Ray-generation program of the intersects_any query. Traces the ray that
// belongs to this launch index and writes whether it hit anything to
// launchParams.results.hit.
extern "C" __global__ void __raygen__intersectsAny() {
  // launch index, ranging in [0, N)
  const int rayIdx = optixGetLaunchIndex().x;
  auto [origin, direction] = getRay(rayIdx);

  // Filled in by the miss / any-hit programs through the payload pointer.
  bool anyHit = false;
  auto [payloadLo, payloadHi] = setPayloadPointer(&anyHit);

  optixTrace(launchParams.traversable, origin, direction,
             0.0f,  // tmin
             1e7f,  // tmax
             0.0f,  // ray time
             OptixVisibilityMask(255), OPTIX_RAY_FLAG_NONE, 0, 0, 0,
             payloadLo, payloadHi);

  launchParams.results.hit[rayIdx] = anyHit;
}
90
91// intersects_first
92
// Miss program of the intersects_first query: stores -1 in the int that the
// payload pointer refers to.
extern "C" __global__ void __miss__intersectsFirst() {
  *getPayloadPointer<int>() = -1;
}
97
// Closest-hit program of the intersects_first query: stores the index of the
// hit primitive in the int that the payload pointer refers to.
extern "C" __global__ void __closesthit__intersectsFirst() {
  *getPayloadPointer<int>() = optixGetPrimitiveIndex();
}
102
// Ray-generation program of the intersects_first query. Traces the ray that
// belongs to this launch index and writes the index of the closest hit
// primitive (or -1) to launchParams.results.triIdx.
extern "C" __global__ void __raygen__intersectsFirst() {
  // launch index, ranging in [0, N)
  const int rayIdx = optixGetLaunchIndex().x;
  auto [origin, direction] = getRay(rayIdx);

  // Overwritten by the miss / closest-hit programs through the payload pointer.
  int firstTri = -1;
  auto [payloadLo, payloadHi] = setPayloadPointer(&firstTri);

  optixTrace(launchParams.traversable, origin, direction,
             0.0f,  // tmin
             1e7f,  // tmax
             0.0f,  // ray time
             OptixVisibilityMask(255), OPTIX_RAY_FLAG_DISABLE_ANYHIT, 0, 0, 0,
             payloadLo, payloadHi);

  launchParams.results.triIdx[rayIdx] = firstTri;
}
117
118// intersects_closest
119
// Per-ray result record of the intersects_closest query. It lives in the
// ray-generation program and is filled in by the miss / closest-hit programs
// through the payload pointer.
struct WBData {
  bool hit;    // true when the ray hit a triangle
  bool front;  // value of optixIsFrontFaceHit() for the hit; false on a miss
  int triIdx;  // index of the hit primitive; -1 on a miss
  float3 loc;  // hit position interpolated from the triangle vertices
  float2 uv;   // {1 - u - v, u} of the triangle barycentrics; {0, 0} on a miss
};
127
// Miss program of the intersects_closest query: resets every field of the
// WBData record behind the payload pointer to its "no hit" value.
extern "C" __global__ void __miss__intersectsClosest() {
  WBData *record = getPayloadPointer<WBData>();
  record->hit = false;
  record->front = false;
  record->triIdx = -1;
  record->loc = float3{0.0f, 0.0f, 0.0f};
  record->uv = float2{0.0f, 0.0f};
}
136
// Closest-hit program of the intersects_closest query: fills the WBData record
// behind the payload pointer with the hit triangle, the hit position, the
// barycentric coordinates and the facing of the hit.
extern "C" __global__ void __closesthit__intersectsClosest() {
  const float2 bary = optixGetTriangleBarycentrics();
  const int primIdx = optixGetPrimitiveIndex();

  // Vertices of the hit triangle as stored in the acceleration structure.
  float3 tri[3];
  optixGetTriangleVertexData(launchParams.traversable, primIdx, 0, 0, tri);

  // Weight of vertex 0; bary.x and bary.y weight vertices 1 and 2.
  const float w0 = 1 - bary.x - bary.y;
  float3 hitPoint;
  hitPoint.x = bary.x * tri[1].x + bary.y * tri[2].x + w0 * tri[0].x;
  hitPoint.y = bary.x * tri[1].y + bary.y * tri[2].y + w0 * tri[0].y;
  hitPoint.z = bary.x * tri[1].z + bary.y * tri[2].z + w0 * tri[0].z;

  WBData *record = getPayloadPointer<WBData>();
  record->hit = true;
  record->front = optixIsFrontFaceHit();
  record->triIdx = primIdx;
  record->loc = hitPoint;
  // Reported as the weights of vertex 0 and vertex 1.
  record->uv = float2{w0, bary.x};
}
154
// Ray-generation program of the intersects_closest query. Traces the ray that
// belongs to this launch index and copies the resulting WBData record into the
// per-ray result buffers.
extern "C" __global__ void __raygen__intersectsClosest() {
  // launch index, ranging in [0, N)
  const int rayIdx = optixGetLaunchIndex().x;
  auto [origin, direction] = getRay(rayIdx);

  // Every field is written by either the miss or the closest-hit program.
  WBData record;
  auto [payloadLo, payloadHi] = setPayloadPointer(&record);

  optixTrace(launchParams.traversable, origin, direction,
             0.0f,  // tmin
             1e7f,  // tmax
             0.0f,  // ray time
             OptixVisibilityMask(255), OPTIX_RAY_FLAG_DISABLE_ANYHIT, 0, 0, 0,
             payloadLo, payloadHi);

  // write back to the buffers
  launchParams.results.hit[rayIdx] = record.hit;
  launchParams.results.front[rayIdx] = record.front;
  launchParams.results.location[rayIdx] = record.loc;
  launchParams.results.triIdx[rayIdx] = record.triIdx;
  launchParams.results.uv[rayIdx] = record.uv;
}
173
174// intersects_location
175
// Any-hit program of the counting pass: increments the counter behind the
// payload pointer and ignores the intersection so that traversal keeps
// reporting further hits along the ray.
extern "C" __global__ void __anyhit__intersectsCount() {
  int *counter = getPayloadPointer<int>();
  // The counter is a local of the ray-generation program that traced this
  // ray, so a plain (non-atomic) increment is used.
  *counter += 1;
  optixIgnoreIntersection();
}
182
// Ray-generation program of the counting pass. Traces the ray that belongs to
// this launch index and writes the number of any-hit invocations to
// launchParams.results.hitCount.
extern "C" __global__ void __raygen__intersectsCount() {
  // launch index, ranging in [0, N)
  const int rayIdx = optixGetLaunchIndex().x;
  auto [origin, direction] = getRay(rayIdx);

  // Incremented by the any-hit program through the payload pointer.
  int counter = 0;
  auto [payloadLo, payloadHi] = setPayloadPointer(&counter);

  optixTrace(launchParams.traversable, origin, direction,
             0.0f,  // tmin
             1e7f,  // tmax
             0.0f,  // ray time
             OptixVisibilityMask(255), OPTIX_RAY_FLAG_NONE, 0, 0, 0,
             payloadLo, payloadHi);

  launchParams.results.hitCount[rayIdx] = counter;
}
195
// One recorded intersection of the intersects_location query (16 bytes).
struct IsectLocWBTerm {
  int triIdx;  // index of the intersected primitive
  float3 loc;  // intersection position interpolated from the triangle vertices
};
200
// Per-ray scratch record of the intersects_location query. It lives in the
// ray-generation program and is filled in by the any-hit program through the
// payload pointer.
struct IsectLocPayload {
  IsectLocWBTerm terms[MAX_ANYHIT_SIZE];  // recorded intersections
  int hitCount;   // number of valid entries in `terms`
  int globalIdx;  // first output slot of this ray in the result buffers
};
206
// Any-hit program of the intersects_location query: appends the current
// intersection (primitive index and position) to the IsectLocPayload behind
// the payload pointer, then ignores the intersection so that traversal keeps
// reporting further hits along the ray.
//
// At most MAX_ANYHIT_SIZE intersections fit into the payload. Once it is full
// nothing more can be recorded, so the ray is terminated explicitly rather
// than returning (which would silently accept the hit and keep traversing).
extern "C" __global__ void __anyhit__intersectsLocation() {
  IsectLocPayload *payload = getPayloadPointer<IsectLocPayload>();
  if (payload->hitCount >= MAX_ANYHIT_SIZE) {
    // Accepts the current hit and stops traversal; no later hit could be
    // stored anyway.
    optixTerminateRay();
    return;
  }

  // Reserve the next free slot.
  const int localIdx = payload->hitCount;
  payload->hitCount++;

  const float2 uv = optixGetTriangleBarycentrics();
  const int triIdx = optixGetPrimitiveIndex();

  // Vertices of the hit triangle as stored in the acceleration structure.
  float3 verts[3];
  optixGetTriangleVertexData(launchParams.traversable, triIdx, 0, 0, verts);

  // uv.x and uv.y weight vertices 1 and 2; the remainder weights vertex 0.
  float3 isectLoc = {
      uv.x * verts[1].x + uv.y * verts[2].x + (1 - uv.x - uv.y) * verts[0].x,
      uv.x * verts[1].y + uv.y * verts[2].y + (1 - uv.x - uv.y) * verts[0].y,
      uv.x * verts[1].z + uv.y * verts[2].z + (1 - uv.x - uv.y) * verts[0].z};

  payload->terms[localIdx].loc = isectLoc;
  payload->terms[localIdx].triIdx = triIdx;
  optixIgnoreIntersection();
}
225
// Ray-generation program of the intersects_location query. Traces the ray that
// belongs to this launch index, lets the any-hit program collect the
// intersections in a local IsectLocPayload, and then copies them into the
// result buffers starting at this ray's output offset.
//
// Inputs read from launchParams.rays:
//   hitCounts[idx] - number of output slots reserved for this ray
//   globalIdx[idx] - index of the first of those slots
//
// The payload holds at most MAX_ANYHIT_SIZE intersections. Reserved slots
// beyond that capacity are filled with a "no hit" entry (triIdx = -1, location
// = 0) instead of reading past the end of payload.terms.
extern "C" __global__ void __raygen__intersectsLocation() {
  // launch index, ranging in [0, N)
  int idx = optixGetLaunchIndex().x;
  int hitCount = launchParams.rays.hitCounts[idx];
  int globalIdx = launchParams.rays.globalIdx[idx];

  // Zero-initialised so that slots the any-hit program never reaches hold
  // well-defined values.
  IsectLocPayload payload = {};
  payload.hitCount = 0;
  payload.globalIdx = globalIdx;

  // ray info
  auto [ray_origin, ray_dir] = getRay(idx);
  // result pointer
  auto [u0, u1] = setPayloadPointer(&payload);
  optixTrace(launchParams.traversable, ray_origin, ray_dir, 0.0f, 1e7f, 0.0f,
             OptixVisibilityMask(255), OPTIX_RAY_FLAG_NONE, 0, 0, 0, u0, u1);

  // fill global buffer
  for (int i = 0; i < hitCount; i++) {
    const int slot = globalIdx + i;
    launchParams.results.rayIdx[slot] = idx;
    if (i < MAX_ANYHIT_SIZE) {
      launchParams.results.triIdx[slot] = payload.terms[i].triIdx;
      launchParams.results.location[slot] = payload.terms[i].loc;
    } else {
      // More slots were reserved than the payload can hold.
      launchParams.results.triIdx[slot] = -1;
      launchParams.results.location[slot] = float3{0.0f, 0.0f, 0.0f};
    }
  }
}
247
248} // namespace hmesh