@@ -94,6 +94,7 @@ bool _validate_shapes(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngi
 void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
+    bool cudagraphs_enabled,
     bool need_cudagraphs_record) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
@@ -127,7 +128,7 @@ void setup_input_tensors(
           compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
           "Error while setting the tensor address for shape inputs");

-      if (CUDAGRAPHS_MODE) {
+      if (cudagraphs_enabled) {
         // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
         compiled_engine->input_buffers[i] = input_cpu;
       }
@@ -147,7 +148,7 @@ void setup_input_tensors(
       TORCHTRT_CHECK(
           compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

-      if (CUDAGRAPHS_MODE) {
+      if (cudagraphs_enabled) {
         // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
         compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
         TORCHTRT_CHECK(
@@ -201,17 +202,17 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     LOG_INFO("" << log_info);
     compiled_engine->cudagraph.enable_debug_mode();
   }
-
+  bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
   bool shape_changed = _validate_shapes(inputs, compiled_engine);

   // Whether cudagraphs needs to record the graph on this pass
   auto result = compiled_engine->runtime_states.set_runtime_states(
-      CUDAGRAPHS_MODE, compiled_engine->use_pre_allocated_outputs, shape_changed);
+      cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);

   bool need_cudagraphs_record = std::get<0>(result);
   bool can_use_pre_allocated_outputs = std::get<1>(result);

-  if (!CUDAGRAPHS_MODE || shape_changed) {
+  if (!cudagraphs_enabled || shape_changed) {
     compiled_engine->cudagraph.reset();
   }

@@ -273,8 +274,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
     }

-    setup_input_tensors(inputs, compiled_engine, need_cudagraphs_record);
-
+    setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
     // Check if input shapes can be inferred.
     int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
     std::vector<char const*> names(io_size);
@@ -306,7 +306,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
         compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
       }

-      if (CUDAGRAPHS_MODE) {
+      if (cudagraphs_enabled) {
         TORCHTRT_CHECK(
             compiled_engine->exec_ctx->setTensorAddress(
                 name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -346,7 +346,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     caller_exec_complete.record(compiled_engine->caller_stream);
     caller_exec_complete.block(compiled_engine->engine_stream);

-    if (!CUDAGRAPHS_MODE) {
+    if (!cudagraphs_enabled) {
       // Direct execution uses the caller buffers directly
       compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
     } else {
@@ -377,7 +377,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     trt_exec_complete.record(compiled_engine->engine_stream);
     trt_exec_complete.block(compiled_engine->caller_stream);

-    if (CUDAGRAPHS_MODE) {
+    if (cudagraphs_enabled) {
       // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
       for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
         outputs[o].copy_(compiled_engine->output_buffers[o], false);
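For context, a minimal sketch of the mode flag this change snapshots into a local bool. Only SUBGRAPH_CUDAGRAPHS appears in the diff itself; the other enumerators, their values, and the extern declaration are assumptions for illustration and may not match the actual runtime headers.

// Sketch of the assumed global cuda-graphs mode (hypothetical names other than SUBGRAPH_CUDAGRAPHS).
typedef enum {
  STANDARD = 0,           // assumed: no CUDA graph capture
  SUBGRAPH_CUDAGRAPHS,    // capture per compiled TRT subgraph (used in the diff)
  WHOLE_GRAPH_CUDAGRAPHS, // assumed: capture the whole wrapper module
} CudaGraphsMode;

extern CudaGraphsMode CUDAGRAPHS_MODE;

// The diff evaluates the mode once per execute_engine call:
//   bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
// so the graph reset, persistent buffer copies, and enqueue path all branch on the
// same value, and setup_input_tensors receives it as a parameter instead of reading
// the global directly.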