Commit cc0de0d

[Build] Propagate build option for CUDA minimal to TRT (microsoft#20695)
### Description
Extend the CUDA minimal build option to the TensorRT provider: with TRT 10, linking against cuDNN is no longer required. In addition, the new engine dump feature makes it possible to embed an engine into an ONNX model, so no builder library has to be shipped. Loading such a model also has roughly the same deserialization/session setup time as using TRT standalone.

### Motivation and Context
```
exe_builder_lib\onnxruntime_perf_test.exe -I -e tensorrt -r 5 -i 'trt_engine_cache_enable|1 trt_timing_cache_enable|1 trt_dump_ep_context_model|1 trt_weightless_engine_enable|1' model.onnx
exe_no_builder_lib\onnxruntime_perf_test.exe -I -e tensorrt -r 5 -i 'trt_engine_cache_enable|1 trt_timing_cache_enable|1 trt_dump_ep_context_model|1 trt_weightless_engine_enable|1' model_ctx.onnx
```
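For reference, the same provider options passed to `onnxruntime_perf_test` above can be set programmatically. A minimal C++ sketch using the public `OrtTensorRTProviderOptionsV2` API, assuming a build with the TensorRT EP enabled; the option keys simply mirror the `-i` string, and error handling is reduced to `Ort::ThrowOnError`:

```cpp
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "trt_ep_context");
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  // Same switches as the perf_test '-i' string above.
  const char* keys[] = {"trt_engine_cache_enable", "trt_timing_cache_enable",
                        "trt_dump_ep_context_model", "trt_weightless_engine_enable"};
  const char* values[] = {"1", "1", "1", "1"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 4));
  Ort::ThrowOnError(
      api.SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options));

  // The first session over model.onnx dumps model_ctx.onnx; a CUDA-minimal
  // build (no builder lib, no cuDNN) can then load model_ctx.onnx directly.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);

  api.ReleaseTensorRTProviderOptions(trt_options);
  return 0;
}
```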
1 parent 307b34a commit cc0de0d

File tree

7 files changed: +50 -20 lines

cmake/onnxruntime_providers_cuda.cmake (6 additions, 4 deletions)

```diff
@@ -33,19 +33,21 @@
 )


-if (onnxruntime_CUDA_MINIMAL)
-  set(onnxruntime_providers_cuda_shared_srcs "")
-else()
+if (NOT onnxruntime_CUDA_MINIMAL)
   file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
     "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
   )
+else()
+  set(onnxruntime_providers_cuda_cu_srcs
+    "${ONNXRUNTIME_ROOT}/core/providers/cuda/math/unary_elementwise_ops_impl.cu"
+  )
 endif()
 source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
 set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

 # disable contrib ops conditionally
-if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
+if(NOT onnxruntime_DISABLE_CONTRIB_OPS AND NOT onnxruntime_CUDA_MINIMAL)
   if (NOT onnxruntime_ENABLE_ATEN)
     list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
       "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/aten_ops/aten_op.cc"
```

cmake/onnxruntime_providers_tensorrt.cmake (11 additions, 3 deletions)

```diff
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
-
+if(onnxruntime_DISABLE_CONTRIB_OPS)
+  message( FATAL_ERROR "To compile TensorRT execution provider contrib ops have to be enabled to dump an engine using com.microsoft:EPContext node." )
+endif()
 add_definitions(-DUSE_TENSORRT=1)
 if (onnxruntime_TENSORRT_PLACEHOLDER_BUILDER)
   add_definitions(-DORT_TENSORRT_PLACEHOLDER_BUILDER)
@@ -154,8 +156,11 @@
 # See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121
 # However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries.
 # Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}.
-set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
-
+if(onnxruntime_CUDA_MINIMAL)
+  set(trt_link_libs ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
+else()
+  set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})
+endif()
 file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
   "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.h"
   "${ONNXRUNTIME_ROOT}/core/providers/tensorrt/*.cc"
@@ -190,6 +195,9 @@
 if (WIN32)
   target_compile_options(onnxruntime_providers_tensorrt INTERFACE /wd4456)
 endif()
+if(onnxruntime_CUDA_MINIMAL)
+  target_compile_definitions(onnxruntime_providers_tensorrt PRIVATE USE_CUDA_MINIMAL=1)
+endif()

 # Needed for the provider interface, as it includes training headers when training is enabled
 if (onnxruntime_ENABLE_TRAINING_OPS)
```

onnxruntime/core/providers/cpu/cpu_provider_shared.cc (5 additions, 6 deletions)

```diff
@@ -192,6 +192,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
   Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<float>* p) override { return p->Run(); }
   Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<double>* p) override { return p->Run(); }
   Status EinsumTypedComputeProcessor__Run(EinsumTypedComputeProcessor<MLFloat16>* p) override { return p->Run(); }
+  void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+                                              gsl::span<const int64_t> input_dims,
+                                              InlinedVector<float>& scales) const override {
+    p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
+  }

 #ifndef DISABLE_CONTRIB_OPS
   Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) override {
@@ -294,12 +299,6 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
   Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
   Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }

-  void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
-                                              gsl::span<const int64_t> input_dims,
-                                              InlinedVector<float>& scales) const override {
-    p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales);
-  }
-
 #ifdef ENABLE_ATEN
   Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
 #endif
```

onnxruntime/core/providers/cpu/cpu_provider_shared.h (3 additions, 5 deletions)

```diff
@@ -141,7 +141,9 @@ struct ProviderHostCPU {
   virtual Status Scan__Compute(const Scan<9>* p, OpKernelContext* ctx) = 0;
   virtual Status Scan__SetupSubgraphExecutionInfo(Scan<8>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
   virtual Status Scan__SetupSubgraphExecutionInfo(Scan<9>* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
-
+  virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
+                                                      gsl::span<const int64_t> input_dims,
+                                                      InlinedVector<float>& scales) const = 0;
 #ifndef DISABLE_CONTRIB_OPS
   virtual Status embed_layer_norm__CheckInputs(const OpKernelContext* context, bool quantizedVersion) = 0;
   virtual Status bias_gelu_helper__CheckInputs(const OpKernelContext* context) = 0;
@@ -203,10 +205,6 @@ struct ProviderHostCPU {
   virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
   virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;

-  virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims,
-                                                      gsl::span<const int64_t> input_dims,
-                                                      InlinedVector<float>& scales) const = 0;
-
 #ifdef ENABLE_ATEN
   virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
 #endif
```
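For context, the moved `UpsampleBase__AdjustOutputSizeAsPolicy` entries follow ONNX Runtime's provider-bridge pattern: the shared provider library reaches CPU-side helpers through a table of virtual functions rather than linking them directly, so moving the declaration out of the `#ifndef DISABLE_CONTRIB_OPS` region keeps it callable from the TensorRT EP even when contrib ops are disabled. A toy sketch of the pattern follows; all names are hypothetical, not the actual ORT interfaces:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Host side: the table of virtuals that lives in the core library.
struct HostInterface {
  virtual ~HostInterface() = default;
  virtual void AdjustOutputSize(std::vector<int64_t>& dims) const = 0;
};

// Implementation compiled into the host; the provider never links it directly.
struct HostImpl final : HostInterface {
  void AdjustOutputSize(std::vector<int64_t>& dims) const override {
    for (auto& d : dims) d = std::max<int64_t>(d, 1);  // stand-in policy logic
  }
};

// Provider side: sees only the interface pointer handed over at load time.
void ProviderComputeOutputShape(const HostInterface& host) {
  std::vector<int64_t> dims{0, 3, -1};
  host.AdjustOutputSize(dims);
  for (int64_t d : dims) std::cout << d << ' ';  // prints: 1 3 1
  std::cout << '\n';
}

int main() {
  HostImpl host;
  ProviderComputeOutputShape(host);
  return 0;
}
```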

onnxruntime/core/providers/cuda/cuda_stream_handle.cc (3 additions, 0 deletions)

```diff
@@ -81,6 +81,9 @@ CudaStream::CudaStream(cudaStream_t stream,
     cudnn_handle_ = external_cudnn_handle;
     CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
   }
+#else
+  (void)(external_cudnn_handle);
+  (void)(external_cublas_handle);
 #endif
 }
```
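A side note on the `(void)(external_cudnn_handle)` lines: casting a parameter to `void` is the standard C++ idiom for marking it intentionally unused in one preprocessor branch, which keeps unused-parameter warnings clean without changing the constructor signature. A toy illustration with hypothetical names:

```cpp
#include <cstdio>

// Compile with -DMINIMAL_BUILD to mirror the USE_CUDA_MINIMAL branch.
void Configure(int handle, int flags) {
#ifndef MINIMAL_BUILD
  std::printf("using handle=%d flags=%d\n", handle, flags);
#else
  (void)(handle);  // intentionally unused in minimal builds
  (void)(flags);
#endif
}

int main() {
  Configure(42, 7);
  return 0;
}
```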

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc (16 additions, 0 deletions)

```diff
@@ -287,6 +287,7 @@ void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const
   return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
 }

+#ifndef USE_CUDA_MINIMAL
 template <>
 Status CudaCall<cublasStatus_t, false>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line) {
   return g_host->CudaCall_false(retCode, exprString, libName, successCode, msg, file, line);
@@ -306,6 +307,7 @@ template <>
 void CudaCall<cudnnStatus_t, true>(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line) {
   return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
 }
+#endif

 #if NV_TENSORRT_MAJOR >= 10
 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
@@ -1119,20 +1121,26 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
 TensorrtExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) {
   if (has_user_compute_stream) {
     CUDA_CALL_THROW(cudaSetDevice(device_id));
+#ifndef USE_CUDA_MINIMAL
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream)));
+#else
+    (void)(stream);
+#endif
   }
 }

 TensorrtExecutionProvider::PerThreadContext::~PerThreadContext() {
+#ifndef USE_CUDA_MINIMAL
   if (external_cublas_handle_ != nullptr) {
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_)));
   }
   if (external_cudnn_handle_ != nullptr) {
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_)));
   }
+#endif
   trt_context_map_.clear();
 }

@@ -1268,10 +1276,12 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
   if (info.has_user_compute_stream) {
     external_stream_ = true;
     stream_ = static_cast<cudaStream_t>(info.user_compute_stream);
+#ifndef USE_CUDA_MINIMAL
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasCreate(&external_cublas_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasSetStream(external_cublas_handle_, stream_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnCreate(&external_cudnn_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnSetStream(external_cudnn_handle_, stream_)));
+#endif
   }

   std::string profile_min_shapes, profile_max_shapes, profile_opt_shapes;
@@ -1442,6 +1452,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
   if (!ep_context_embed_mode_env.empty()) {
     ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
   }
+  // in case the EP context is dumped, the engine cache has to be enabled
+  if (dump_ep_context_model_ && ep_context_embed_mode_ == 0) {
+    engine_cache_enable_ = true;
+  }

   enable_engine_cache_for_ep_context_model();

@@ -1737,8 +1751,10 @@ TensorrtExecutionProvider::~TensorrtExecutionProvider() {
   }

   if (external_stream_) {
+#ifndef USE_CUDA_MINIMAL
     ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(external_cublas_handle_)));
     ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(external_cudnn_handle_)));
+#endif
   }

   if (!external_stream_ && stream_) {
```
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h (6 additions, 2 deletions)

```diff
@@ -3,9 +3,13 @@

 #pragma once
 #include <ctime>
+#ifndef USE_CUDA_MINIMAL
 #include <cudnn.h>
-#include <cublas_v2.h>
-
+#else
+typedef void* cudnnHandle_t;
+typedef void* cublasHandle_t;
+typedef void* cudnnStatus_t;
+#endif
 #include "core/providers/tensorrt/nv_includes.h"

 #include "core/platform/ort_mutex.h"
```
