Skip to content

Commit 16ae99e

Browse files
umangb-09, gedoensmax, and gaugarg-nv
authored
Add cuda graph implementation for NV TRT RTX EP (microsoft#25787)
### Description

This change adds CUDA Graph support to the NV TensorRT RTX Execution Provider (EP).

### Motivation and Context

Integrating CUDA Graphs into the NV TRT RTX EP provides:

- Lower latency by minimizing per-kernel launch overhead.
- Better throughput for repeated inference runs.
- Improved efficiency on GPUs that are sensitive to kernel launch overhead.

---------

Co-authored-by: Maximilian Mueller <maximilianm@nvidia.com>
Co-authored-by: Gaurav Garg <gaugarg@nvidia.com>
1 parent 7e3174b commit 16ae99e

File tree

5 files changed

+267
-137
lines changed

5 files changed

+267
-137
lines changed

include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ constexpr const char* kDetailedBuildLog = "nv_detailed_build_log";
3131
constexpr const char* kProfilesMinShapes = "nv_profile_min_shapes";
3232
constexpr const char* kProfilesMaxShapes = "nv_profile_max_shapes";
3333
constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes";
34-
constexpr const char* kCudaGraphEnable = "nv_cuda_graph_enable";
34+
constexpr const char* kCudaGraphEnable = "enable_cuda_graph";
3535
constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable";
3636
constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer";
3737

onnxruntime/core/providers/cuda/cuda_graph.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id
7272
cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec);
7373
}
7474

75-
Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) {
75+
Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag) {
7676
// Although this function is not thread safe, the lock is not needed here because
7777
// CUDA EP maintains a separate cuda graph per thread
7878
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id "
@@ -81,7 +81,9 @@ Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id)
8181
cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id);
8282
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_));
8383

84-
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
84+
if (sync_status_flag) {
85+
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
86+
}
8587
return Status::OK();
8688
}
8789

onnxruntime/core/providers/cuda/cuda_graph.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ struct CUDAGraphManager {
3838
void SetStream(cudaStream_t stream);
3939
void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
4040
void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
41-
Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id);
41+
Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id, bool sync_status_flag = true);
4242

4343
void Reset();
4444

0 commit comments

Comments
 (0)