2#include "LaunchParams.h"
3#include "optix_types.h"
4#include <cuda_device_runtime_api.h>
5#include <optix_device.h>
// Launch parameters shared by every program in this file, declared with C
// linkage in __constant__ memory. The programs below read the ray inputs
// (launchParams.rays), the traversable handle (launchParams.traversable) and
// the output buffers (launchParams.results) from it.
10extern "C" __constant__ LaunchParams launchParams;
// Splits the pointer `p` into two 32-bit words so that it can travel in two
// 32-bit ray payload slots: u0 receives the low 32 bits of the address and
// u1 the high 32 bits. getPayloadPointer() performs the inverse operation.
//
// NOTE(review): the return statement is not shown in this view of the file;
// callers destructure the result as `auto [u0, u1]`, so presumably the pair
// {u0, u1} is returned — confirm.
12__forceinline__ __device__ std::tuple<unsigned int, unsigned int>
13setPayloadPointer(void *p) {
 // low half of the 64-bit address
14 unsigned int u0 = (unsigned long long)p & 0xFFFFFFFFllu;
 // high half of the 64-bit address
15 unsigned int u1 = ((unsigned long long)p >> 32) & 0xFFFFFFFFllu;
// Reassembles the pointer that setPayloadPointer() split across payload
// slots 0 and 1 of the current ray.
//
// Template parameter T: pointee type expected by the caller.
// Returns: the 64-bit address (slot 1 = high 32 bits, slot 0 = low 32 bits)
// reinterpreted as T*.
//
// Device-only: optixGetPayload_0/1 are OptiX device intrinsics, so the
// function carries no __host__ qualifier (matching setPayloadPointer).
template <typename T>
__forceinline__ __device__ T *getPayloadPointer() {
  const unsigned long long lo = optixGetPayload_0();
  const unsigned long long hi = optixGetPayload_1();
  // The two halves occupy disjoint bit ranges, so OR-ing them rebuilds the
  // original address.
  return reinterpret_cast<T *>((hi << 32) | lo);
}
// Derives per-dimension indices from the flat index `idx`, walking the
// dimensions from the last one to the first.
//   indices: output array with MAX_SIZE_LENGTH entries
//   shape:   extent of each dimension; each entry is used as a modulo
//            divisor and therefore must be non-zero
//   idx:     flat index to decompose
// NOTE(review): only the modulo step is shown in this view of the file;
// presumably idx is also divided by shape[i] on every iteration — confirm.
27__forceinline__ __device__ void getIndices(int64_t indices[MAX_SIZE_LENGTH], int64_t shape[MAX_SIZE_LENGTH], int idx) {
29 for (int i = MAX_SIZE_LENGTH - 1; i >= 0; i--) {
 // position along dimension i
30 indices[i] = idx % shape[i];
// Loads the origin and direction of ray `idx` from launchParams.rays.
//
// The flat position idx * 3 is expanded into per-dimension indices with
// getIndices() using launchParams.rays.rayShape. Those indices are combined
// with originsStride / directionsStride to obtain element offsets into the
// origins and directions buffers. The x, y and z components are then read at
// offset, offset + s and offset + 2 * s, where s is the stride of the last
// dimension of the respective buffer.
//
// Returns: {ray_origin, ray_dir}.
//
// NOTE(review): idx * 3 is evaluated in int and could overflow for very
// large launches — confirm the expected ray count.
// NOTE(review): the declaration of ray_origin / ray_dir is not shown in this
// view of the file.
35__forceinline__ __device__ std::tuple<float3, float3> getRay(int idx) {
36 // corresponding float idx in [0, 3N)
37 int float_idx = idx * 3;
38 // thread index in all dims
39 int64_t indices[MAX_SIZE_LENGTH];
40 getIndices(indices, launchParams.rays.rayShape, float_idx);
41 // index in the flat array
42 int64_t ori_real_idx = 0;
43 int64_t dir_real_idx = 0;
 // dot product of the indices with each buffer's strides
45 for (int i = 0; i < MAX_SIZE_LENGTH; i++) {
46 ori_real_idx += indices[i] * launchParams.rays.originsStride[i];
47 dir_real_idx += indices[i] * launchParams.rays.directionsStride[i];
 // the three components of the origin are last_stride elements apart
51 int64_t last_stride = launchParams.rays.originsStride[MAX_SIZE_LENGTH - 1];
52 ray_origin.x = launchParams.rays.origins[ori_real_idx];
53 ray_origin.y = launchParams.rays.origins[ori_real_idx + last_stride];
54 ray_origin.z = launchParams.rays.origins[ori_real_idx + 2 * last_stride];
 // the directions buffer may use a different last-dimension stride
56 last_stride = launchParams.rays.directionsStride[MAX_SIZE_LENGTH - 1];
57 ray_dir.x = launchParams.rays.directions[dir_real_idx];
58 ray_dir.y = launchParams.rays.directions[dir_real_idx + last_stride];
59 ray_dir.z = launchParams.rays.directions[dir_real_idx + 2 * last_stride];
60 // printf("idx: %d, ray_origin: (%f, %f, %f), ray_dir: (%f, %f, %f), ori_real_idx: %ld, dir_real_idx: %ld, indices: (%ld, %ld, %ld, %ld)\n",
61 // idx, ray_origin.x, ray_origin.y, ray_origin.z, ray_dir.x, ray_dir.y, ray_dir.z, ori_real_idx, dir_real_idx, indices[0], indices[1], indices[2], indices[3]);
62 return {ray_origin, ray_dir};
// Miss program of the "intersectsAny" query.
// Rebuilds the bool* that __raygen__intersectsAny packed into payload slots
// 0 and 1 (the address of its local isect_result).
// NOTE(review): the store through result_pt is not shown in this view of the
// file; presumably it records "no hit" — confirm.
67extern "C" __global__ void __miss__intersectsAny() {
68 bool *result_pt = getPayloadPointer<bool>();
// Any-hit program of the "intersectsAny" query.
// Rebuilds the bool* that __raygen__intersectsAny packed into payload slots
// 0 and 1 (the address of its local isect_result).
// NOTE(review): the store through result_pt is not shown in this view of the
// file; presumably it records a hit — confirm.
72extern "C" __global__ void __anyhit__intersectsAny() {
73 bool *result_pt = getPayloadPointer<bool>();
// Ray-generation program of the "intersectsAny" query.
//
// Each launch index (x dimension) handles one ray: the ray is loaded with
// getRay(), traced against launchParams.traversable, and the boolean outcome
// is stored in launchParams.results.hit[idx]. The outcome lives in a local
// variable whose address is handed to the hit/miss programs through payload
// slots 0 and 1.
extern "C" __global__ void __raygen__intersectsAny() {
  // thread index, ranging in [0, N)
  const int idx = optixGetLaunchIndex().x;
  // intersection result, to be overwritten by the shader
  bool isect_result = false;

  auto [ray_origin, ray_dir] = getRay(idx);

  auto [u0, u1] = setPayloadPointer(&isect_result);
  // tmin, tmax and rayTime are float parameters of optixTrace, so they are
  // written as float literals rather than double/int ones.
  optixTrace(launchParams.traversable, ray_origin, ray_dir, 0.0f, 1e7f, 0.0f,
             OptixVisibilityMask(255), OPTIX_RAY_FLAG_NONE, 0, 0, 0, u0, u1);
  launchParams.results.hit[idx] = isect_result;
}
// Miss program of the "intersectsFirst" query.
// Rebuilds the int* that __raygen__intersectsFirst packed into payload slots
// 0 and 1 (the address of its local ch_idx).
// NOTE(review): the value written on a miss is not shown in this view of the
// file — confirm which sentinel is used.
93extern "C" __global__ void __miss__intersectsFirst() {
94 int *result_pt = getPayloadPointer<int>();
// Closest-hit program of the "intersectsFirst" query.
// Writes the primitive index of the hit, as reported by
// optixGetPrimitiveIndex(), through the int* carried in payload slots 0/1.
98extern "C" __global__ void __closesthit__intersectsFirst() {
99 int *result_pt = getPayloadPointer<int>();
100 *result_pt = optixGetPrimitiveIndex();
// Ray-generation program of the "intersectsFirst" query.
//
// Loads ray `idx` with getRay(), traces it with any-hit programs disabled
// (OPTIX_RAY_FLAG_DISABLE_ANYHIT) and stores ch_idx, which the hit/miss
// programs overwrite through the payload pointer, in
// launchParams.results.triIdx[idx].
//
// NOTE(review): the declaration and initial value of ch_idx and the last
// line of the optixTrace call are not shown in this view of the file.
// NOTE(review): `0.` and `1e7` are double literals and `0` an int literal;
// presumably they are converted to the float tmin / tmax / rayTime
// parameters of optixTrace — confirm against the OptiX headers.
103extern "C" __global__ void __raygen__intersectsFirst() {
104 // thread index, ranging in [0, N)
105 int idx = optixGetLaunchIndex().x;
106 // first hit triangle index, to be overwritten by the shader
109 auto [ray_origin, ray_dir] = getRay(idx);
 // pass the address of the local result through payload slots 0 and 1
111 auto [u0, u1] = setPayloadPointer(&ch_idx);
112 optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
113 OptixVisibilityMask(255), OPTIX_RAY_FLAG_DISABLE_ANYHIT, 0, 0, 0,
115 launchParams.results.triIdx[idx] = ch_idx;
// Miss program of the "intersectsClosest" query.
// Fills the WBData record referenced by payload slots 0/1 for a ray that hit
// nothing: the location is set to the origin (0, 0, 0) and `front` to false.
// NOTE(review): assignments to the remaining WBData fields (hit, triIdx, uv
// are read back by the raygen program) are not shown in this view of the
// file — confirm they are set here.
128extern "C" __global__ void __miss__intersectsClosest() {
129 WBData *result = getPayloadPointer<WBData>();
133 result->loc = {0, 0, 0};
134 result->front = false;
// Closest-hit program of the "intersectsClosest" query.
//
// Fetches the hit triangle's vertices, interpolates the hit location from
// the barycentrics (u, v) as u * v1 + v * v2 + (1 - u - v) * v0, and fills
// the WBData record referenced by payload slots 0/1 with the primitive
// index, the barycentric pair {1 - u - v, u}, the front-face flag and the
// location.
//
// NOTE(review): the declarations of `verts` and `isectLoc`, and any store to
// result->hit, are not shown in this view of the file.
// NOTE(review): optixGetTriangleVertexData is given launchParams.traversable,
// the same handle that is traced; the OptiX documentation describes the
// first argument as a geometry acceleration structure handle — confirm the
// traversable is a GAS built with random vertex access enabled.
137extern "C" __global__ void __closesthit__intersectsClosest() {
138 WBData *result = getPayloadPointer<WBData>();
139 float2 uv = optixGetTriangleBarycentrics();
140 int triIdx = optixGetPrimitiveIndex();
142 optixGetTriangleVertexData(launchParams.traversable, triIdx, 0, 0, verts);
 // x, y and z of the barycentric interpolation of the three vertices
144 uv.x * verts[1].x + uv.y * verts[2].x + (1 - uv.x - uv.y) * verts[0].x,
145 uv.x * verts[1].y + uv.y * verts[2].y + (1 - uv.x - uv.y) * verts[0].y,
146 uv.x * verts[1].z + uv.y * verts[2].z + (1 - uv.x - uv.y) * verts[0].z};
148 result->triIdx = triIdx;
 // stored pair is (weight of verts[0], weight of verts[1])
149 result->uv = {1 - uv.x - uv.y, uv.x};
151 result->front = optixIsFrontFaceHit();
152 result->loc = isectLoc;
// Ray-generation program of the "intersectsClosest" query.
//
// Loads ray `idx` with getRay(), traces it with any-hit programs disabled
// (OPTIX_RAY_FLAG_DISABLE_ANYHIT) while passing the address of a local
// WBData record through payload slots 0/1, then copies the record's fields
// (hit, front, loc, triIdx, uv) into the per-ray result buffers at `idx`.
//
// NOTE(review): the declaration of `wbdata` and the last line of the
// optixTrace call are not shown in this view of the file.
155extern "C" __global__ void __raygen__intersectsClosest() {
156 // thread index, ranging in [0, N)
157 int idx = optixGetLaunchIndex().x;
160 auto [ray_origin, ray_dir] = getRay(idx);
162 auto [u0, u1] = setPayloadPointer(&wbdata);
163 optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
164 OptixVisibilityMask(255), OPTIX_RAY_FLAG_DISABLE_ANYHIT, 0, 0, 0,
166 // write back to the buffers
167 launchParams.results.hit[idx] = wbdata.hit;
168 launchParams.results.front[idx] = wbdata.front;
169 launchParams.results.location[idx] = wbdata.loc;
170 launchParams.results.triIdx[idx] = wbdata.triIdx;
171 launchParams.results.uv[idx] = wbdata.uv;
174// intersects_location
// Any-hit program of the "intersectsCount" query.
// Rebuilds the int* that __raygen__intersectsCount packed into payload slots
// 0 and 1 (the address of its local hitCount) and then calls
// optixIgnoreIntersection() on the current candidate hit.
// NOTE(review): the update of *hitCount is not shown in this view of the
// file; presumably the counter is incremented once per invocation — confirm.
176extern "C" __global__ void __anyhit__intersectsCount() {
177 int *hitCount = getPayloadPointer<int>();
178 // it seems we don't need atomic ops as they are not parallel
 // reject the candidate hit (presumably so traversal keeps reporting
 // further intersections along the ray — confirm against the OptiX docs)
180 optixIgnoreIntersection();
// Ray-generation program of the "intersectsCount" query.
//
// Loads ray `idx` with getRay(), traces it with OPTIX_RAY_FLAG_NONE while
// passing the address of a local counter through payload slots 0/1, and
// stores the counter in launchParams.results.hitCount[idx].
//
// NOTE(review): the declaration and initial value of the local `hitCount`
// are not shown in this view of the file; presumably it starts at 0 —
// confirm.
183extern "C" __global__ void __raygen__intersectsCount() {
184 // thread index, ranging in [0, N)
185 int idx = optixGetLaunchIndex().x;
188 auto [ray_origin, ray_dir] = getRay(idx);
190 auto [u0, u1] = setPayloadPointer(&hitCount);
191 optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
192 OptixVisibilityMask(255), OPTIX_RAY_FLAG_NONE, 0, 0, 0, u0, u1);
193 launchParams.results.hitCount[idx] = hitCount;
// One recorded intersection of the "intersectsLocation" query.
// The programs below access the members `loc` (hit location) and `triIdx`
// (primitive index) of this record.
// NOTE(review): the member declarations are not shown in this view of the
// file.
196struct IsectLocWBTerm {
// Per-ray payload of the "intersectsLocation" query: a fixed-capacity list
// of recorded intersections. The programs below also access the members
// `hitCount` and `globalIdx` of this record.
// NOTE(review): the declarations of hitCount and globalIdx are not shown in
// this view of the file.
201struct IsectLocPayload {
 // storage for at most MAX_ANYHIT_SIZE recorded intersections
202 IsectLocWBTerm terms[MAX_ANYHIT_SIZE];
// Any-hit program of the "intersectsLocation" query.
//
// Records the current candidate hit in the IsectLocPayload referenced by
// payload slots 0/1: the slot index is the payload's current hitCount, and
// the slot receives the primitive index plus the hit location interpolated
// from the barycentrics (u, v) as u * v1 + v * v2 + (1 - u - v) * v0.
// The candidate is then rejected with optixIgnoreIntersection().
//
// NOTE(review): the statement guarded by the capacity check, the update of
// payload->hitCount, and the declarations of `verts` and `isectLoc` are not
// shown in this view of the file.
// NOTE(review): optixGetTriangleVertexData is given launchParams.traversable;
// confirm that this handle is a geometry acceleration structure.
207extern "C" __global__ void __anyhit__intersectsLocation() {
208 IsectLocPayload *payload = getPayloadPointer<IsectLocPayload>();
 // capacity check: terms[] holds MAX_ANYHIT_SIZE entries
209 if (payload->hitCount >= MAX_ANYHIT_SIZE)
211 int localidx = payload->hitCount;
213 float2 uv = optixGetTriangleBarycentrics();
214 int triIdx = optixGetPrimitiveIndex();
216 optixGetTriangleVertexData(launchParams.traversable, triIdx, 0, 0, verts);
 // x, y and z of the barycentric interpolation of the three vertices
218 uv.x * verts[1].x + uv.y * verts[2].x + (1 - uv.x - uv.y) * verts[0].x,
219 uv.x * verts[1].y + uv.y * verts[2].y + (1 - uv.x - uv.y) * verts[0].y,
220 uv.x * verts[1].z + uv.y * verts[2].z + (1 - uv.x - uv.y) * verts[0].z};
221 payload->terms[localidx].loc = isectLoc;
222 payload->terms[localidx].triIdx = triIdx;
223 optixIgnoreIntersection();
// Ray-generation program of the "intersectsLocation" query.
//
// Reads this ray's hit count (launchParams.rays.hitCounts[idx]) and its
// first output slot (launchParams.rays.globalIdx[idx]), traces the ray while
// the any-hit program collects intersections into a local, zero-initialised
// IsectLocPayload, and then copies `hitCount` entries into the result
// buffers rayIdx / triIdx / location starting at globalIdx.
//
// NOTE(review): the copy loop is bounded by hitCounts[idx], whereas
// payload.terms holds only MAX_ANYHIT_SIZE entries and the any-hit program
// checks hitCount against that capacity. If hitCounts[idx] can exceed
// MAX_ANYHIT_SIZE, payload.terms[i] is read out of bounds — confirm that the
// host clamps the counts, or bound the loop by MAX_ANYHIT_SIZE.
226extern "C" __global__ void __raygen__intersectsLocation() {
227 // thread index, ranging in [0, N)
228 int idx = optixGetLaunchIndex().x;
229 int hitCount = launchParams.rays.hitCounts[idx];
230 int globalIdx = launchParams.rays.globalIdx[idx];
231 IsectLocPayload payload = {};
232 payload.hitCount = 0;
233 payload.globalIdx = globalIdx;
235 auto [ray_origin, ray_dir] = getRay(idx);
 // pass the address of the local payload through payload slots 0 and 1
237 auto [u0, u1] = setPayloadPointer(&payload);
238 optixTrace(launchParams.traversable, ray_origin, ray_dir, 0., 1e7, 0,
239 OptixVisibilityMask(255), OPTIX_RAY_FLAG_NONE, 0, 0, 0, u0, u1);
240 // fill global buffer
241 for (int i = 0; i < hitCount; i++) {
242 launchParams.results.rayIdx[globalIdx + i] = idx;
243 launchParams.results.triIdx[globalIdx + i] = payload.terms[i].triIdx;
244 launchParams.results.location[globalIdx + i] = payload.terms[i].loc;