triro 1.3.0
A Python Ray-Mesh Intersector in OptiX
Loading...
Searching...
No Matches
base.cpp
Go to the documentation of this file.
1
2#include "CUDABuffer.h"
3#include "embedded/shaders_embedded.h"
4#include "optix8.h"
5#include "optix_host.h"
6#include "optix_types.h"
7#include "sbtdef.h"
8#include <cstddef>
9#include <iostream>
10#include <optix_function_table_definition.h>
11#include <torch/extension.h>
12
13namespace hmesh {
14
15CUcontext cuCtx;
16CUstream cuStream;
17OptixDeviceContext optixContext;
18OptixModule optixModule;
19OptixPipelineCompileOptions pipelineCompileOptions = {};
20
21// SBTs and pipelines for each function
22OptixShaderBindingTable sbts[SBTType::count];
24OptixProgramGroup optixProgramGroups[SBTType::count][3];
25
// OptiX log sink installed on the device context: every message is written to
// stderr as "[<level>][<tag>]: <message>" followed by a newline. The user-data
// pointer is ignored.
static void context_log_cb(unsigned int level, const char *tag,
                           const char *message, void * /*cbdata*/) {
  std::ostream &sink = std::cerr;
  const int numericLevel = static_cast<int>(level);
  sink << "[" << numericLevel;
  sink << "][" << tag;
  sink << "]: " << message;
  sink << "\n";
}
30
31void initOptix() { optixInit(); }
32
34 cudaStreamCreate(&cuStream);
35 CUresult res = cuCtxGetCurrent(&cuCtx);
36 if (res != CUDA_SUCCESS)
37 std::cerr << "Error getting current CUDA context: error code " << res
38 << "\n";
39 optixDeviceContextCreate(cuCtx, nullptr, &optixContext);
40 optixDeviceContextSetLogCallback(optixContext, context_log_cb, nullptr, 4);
41}
42
44 // create module
45 OptixModuleCompileOptions moduleCompileOptions = {};
46 moduleCompileOptions.maxRegisterCount =
47 OPTIX_COMPILE_DEFAULT_MAX_REGISTER_COUNT;
48 moduleCompileOptions.optLevel = OPTIX_COMPILE_OPTIMIZATION_DEFAULT;
49 moduleCompileOptions.debugLevel = OPTIX_COMPILE_DEBUG_LEVEL_MINIMAL;
50 moduleCompileOptions.numPayloadTypes = 0;
51 moduleCompileOptions.payloadTypes = nullptr;
52
53 pipelineCompileOptions.usesMotionBlur = false;
54 // !must be ALLOW_SINGLE_GAS for only one GAS
55 pipelineCompileOptions.traversableGraphFlags =
56 OPTIX_TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS;
57 pipelineCompileOptions.allowOpacityMicromaps = false;
58 // 0: hitT
59 // 1: hitKind
60 pipelineCompileOptions.numAttributeValues = 2;
61 // 0: resultpointer low 32 bits
62 // 1: resultpointer high 32 bits
63 pipelineCompileOptions.numPayloadValues = 2;
64 pipelineCompileOptions.exceptionFlags = OPTIX_EXCEPTION_FLAG_NONE;
65 pipelineCompileOptions.usesPrimitiveTypeFlags =
66 OPTIX_PRIMITIVE_TYPE_FLAGS_TRIANGLE;
67 pipelineCompileOptions.pipelineLaunchParamsVariableName = "launchParams";
68
69 char logString[2048];
70 size_t logStringSize = 2048;
71 OPTIX_CHECK(optixModuleCreate(optixContext, &moduleCompileOptions,
73 (const char *)shader_code, shader_length,
74 logString, &logStringSize, &optixModule));
75}
76
78 for (int t = 0; t < SBTType::count; t++) {
79 // program affix
80 const std::string &prgName = std::get<0>(programInfos[t]);
81 const std::string raygenName = std::string("__raygen__") + prgName;
82 const std::string anyhitName = std::string("__anyhit__") + prgName;
83 const std::string closesthitName =
84 std::string("__closesthit__") + prgName;
85 const std::string intersectionName =
86 std::string("__intersection__") + prgName;
87 const std::string missName = std::string("__miss__") + prgName;
88 ProgramMask prgMask = std::get<1>(programInfos[t]);
89 // program group descriptors
90 // { RAYGEN, HITGROUP, MISS }
91 OptixProgramGroupDesc pgDescs[3] = {};
92 // raygen program group
93 pgDescs[0].kind = OPTIX_PROGRAM_GROUP_KIND_RAYGEN;
94 pgDescs[0].raygen.module = optixModule;
95 pgDescs[0].raygen.entryFunctionName = raygenName.c_str();
96 // hitgroup program group
97 pgDescs[1].kind = OPTIX_PROGRAM_GROUP_KIND_HITGROUP;
98 if (prgMask & PRG_AH) {
99 pgDescs[1].hitgroup.moduleAH = optixModule;
100 pgDescs[1].hitgroup.entryFunctionNameAH = anyhitName.c_str();
101 }
102 if (prgMask & PRG_CH) {
103 pgDescs[1].hitgroup.moduleCH = optixModule;
104 pgDescs[1].hitgroup.entryFunctionNameCH = closesthitName.c_str();
105 }
106 if (prgMask & PRG_IS) {
107 pgDescs[1].hitgroup.moduleIS = optixModule;
108 pgDescs[1].hitgroup.entryFunctionNameIS = intersectionName.c_str();
109 }
110 // miss program group
111 pgDescs[2].kind = OPTIX_PROGRAM_GROUP_KIND_MISS;
112 if (prgMask & PRG_MS) {
113 pgDescs[2].miss.module = optixModule;
114 pgDescs[2].miss.entryFunctionName = missName.c_str();
115 }
116 // program group options
117 OptixProgramGroupOptions pgOptions[3] = {};
118 // create program group
119 char logString[2048];
120 size_t logStringSize = 2048;
121 OptixProgramGroup *pg = optixProgramGroups[t];
122 OPTIX_CHECK(optixProgramGroupCreate(optixContext, pgDescs, 3, pgOptions,
123 logString, &logStringSize, pg));
124 // create pipeline
125 OptixPipelineLinkOptions pipelineLinkOptions = {};
126 pipelineLinkOptions.maxTraceDepth = 1;
127
128 OPTIX_CHECK(optixPipelineCreate(optixContext, &pipelineCompileOptions,
129 &pipelineLinkOptions, pg, 3, logString,
130 &logStringSize, &optixPipelines[t]));
131 }
132}
133
134void buildSBT() {
135 for (int t = 0; t < SBTType::count; t++) {
136 // create SBT header
137 SBTRecordEmpty tmpRec;
138 CUDABuffer rgRecDevice;
139 optixSbtRecordPackHeader(optixProgramGroups[t][0], &tmpRec);
140 rgRecDevice.alloc_and_upload(&tmpRec, 1);
141 CUDABuffer hgRecDevice;
142 optixSbtRecordPackHeader(optixProgramGroups[t][1], &tmpRec);
143 hgRecDevice.alloc_and_upload(&tmpRec, 1);
144 CUDABuffer msRecDevice;
145 optixSbtRecordPackHeader(optixProgramGroups[t][2], &tmpRec);
146 msRecDevice.alloc_and_upload(&tmpRec, 1);
147 // fill sbt
148 OptixShaderBindingTable &sbt = sbts[t];
149 sbt.raygenRecord = rgRecDevice.d_pointer();
150 sbt.hitgroupRecordBase = hgRecDevice.d_pointer();
151 sbt.hitgroupRecordCount = 1;
152 sbt.hitgroupRecordStrideInBytes = sizeof(SBTRecordEmpty);
153 sbt.missRecordBase = msRecDevice.d_pointer();
154 sbt.missRecordCount = 1;
155 sbt.missRecordStrideInBytes = sizeof(SBTRecordEmpty);
156 }
157}
158
159} // namespace hmesh
namespace hmesh
Definition base.cpp:13
OptixShaderBindingTable sbts[SBTType::count]
Definition base.cpp:22
OptixModule optixModule
Definition base.cpp:18
void initOptix()
Definition base.cpp:31
OptixPipeline optixPipelines[SBTType::count]
Definition base.cpp:23
OptixProgramGroup optixProgramGroups[SBTType::count][3]
Definition base.cpp:24
void createOptixContext()
Definition base.cpp:33
OptixPipelineCompileOptions pipelineCompileOptions
Definition base.cpp:19
void createPipelines()
Definition base.cpp:77
void createOptixModule()
Definition base.cpp:43
OptixDeviceContext optixContext
Definition base.cpp:17
CUcontext cuCtx
Definition base.cpp:15
CUstream cuStream
Definition base.cpp:16
void buildSBT()
Definition base.cpp:134
#define OPTIX_CHECK(call)
Definition optix8.h:41
#define PRG_IS
Definition sbtdef.h:23
const std::tuple< std::string, ProgramMask > programInfos[]
Definition sbtdef.h:37
int ProgramMask
Definition sbtdef.h:19
#define PRG_AH
Definition sbtdef.h:25
#define PRG_MS
Definition sbtdef.h:29
SBTRecord< void * > SBTRecordEmpty
Definition sbtdef.h:50
#define PRG_CH
Definition sbtdef.h:27
@ count
Definition sbtdef.h:15
CUdeviceptr d_pointer() const
Definition CUDABuffer.h:29
void alloc_and_upload(const std::vector< T > &vt)
Definition CUDABuffer.h:52