@@ -117,7 +117,6 @@ void Impl_Cast(
 }
 }  // namespace cuda

-#if NV_TENSORRT_MAJOR >= 10
 void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
                                              uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept {
   // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
@@ -134,25 +133,6 @@ void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /
   // if cudaMalloc fails, returns nullptr.
   return outputPtr;
 }
-#else
-// Only override this method when TensorRT <= 8.6
-void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
-                                        uint64_t /*alignment*/) noexcept {
-  // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
-  // even for empty tensors, so allocate a dummy byte.
-  size = std::max(size, static_cast<uint64_t>(1));
-  if (size > allocated_size) {
-    cudaFree(outputPtr);
-    outputPtr = nullptr;
-    allocated_size = 0;
-    if (cudaMalloc(&outputPtr, size) == cudaSuccess) {
-      allocated_size = size;
-    }
-  }
-  // if cudaMalloc fails, returns nullptr.
-  return outputPtr;
-}
-#endif

 void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept {
   output_shapes.clear();
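
For context on the allocator pattern in the hunk above, here is a minimal, self-contained sketch of a grow-only TensorRT output allocator. It mirrors the dummy-byte and grow-only logic of the EP's version; the class and member names are illustrative, not the EP's, and it assumes TensorRT 10 headers where reallocateOutputAsync is the virtual to override.

#include <algorithm>

#include <NvInfer.h>
#include <cuda_runtime_api.h>

// Illustrative grow-only allocator: TensorRT may request a new output size on
// every enqueue, so the buffer is only replaced when it must grow.
class GrowOnlyOutputAllocator : public nvinfer1::IOutputAllocator {
 public:
  void* reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
                              uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept override {
    size = std::max(size, static_cast<uint64_t>(1));  // dummy byte for empty tensors
    if (size > allocated_size_) {
      cudaFree(ptr_);
      ptr_ = nullptr;
      allocated_size_ = 0;
      if (cudaMalloc(&ptr_, size) == cudaSuccess) {
        allocated_size_ = size;
      }
    }
    return ptr_;  // nullptr tells TensorRT the allocation failed
  }

  // TensorRT reports the final output shape here once it is known.
  void notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& /*dims*/) noexcept override {}

 private:
  void* ptr_ = nullptr;
  uint64_t allocated_size_ = 0;
};

// Usage sketch, assuming trt_context is an nvinfer1::IExecutionContext*:
//   GrowOnlyOutputAllocator alloc;
//   trt_context->setOutputAllocator("output_0", &alloc);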
@@ -912,6 +892,7 @@ Status BindKernelOutput(Ort::KernelContext& ctx,
 }

 NvExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId device_id, bool has_user_compute_stream, cudaStream_t stream) {
+  // TODO: figure out if PerThreadContext is used at all. If not, just clean it up.
   if (has_user_compute_stream) {
     CUDA_CALL_THROW(cudaSetDevice(device_id));
     (void)(stream);
@@ -1046,8 +1027,16 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
       info_(info),
       device_id_(info.device_id) {
   InitProviderOrtApi();
+
   // TODO(maximlianm) remove this since we should be able to compile an AOT context file without GPU
-  CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
+  if (!info.has_user_compute_stream) {
+    // If the app passes in a compute stream, it has already initialized CUDA and created a context.
+    // Calling cudaSetDevice() would set the default context on the current thread,
+    // which may not be compatible with the stream created by the app.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+  }
+
   cudaDeviceProp prop;
   CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_));
   compute_capability_ = GetComputeCapacity(prop);
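
A standalone illustration of the hazard the new has_user_compute_stream check avoids. This is not EP code; it assumes the behavior the added comment describes, namely that cudaSetDevice() binds the device's default (primary) context on the calling thread.

#include <cuda.h>
#include <cuda_runtime_api.h>

int main() {
  cuInit(0);
  CUdevice dev;
  cuDeviceGet(&dev, 0);

  // The app creates its own (non-primary) context and a stream inside it.
  CUcontext app_ctx;
  cuCtxCreate(&app_ctx, 0, dev);
  cudaStream_t app_stream;
  cudaStreamCreate(&app_stream);  // tied to app_ctx, the current context

  // If a library now calls cudaSetDevice(), the primary context becomes
  // current on this thread, and app_stream no longer matches the current
  // context -- the incompatibility the EP comment describes.
  cudaSetDevice(0);

  cuCtxDestroy(app_ctx);
  return 0;
}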
@@ -1068,6 +1057,7 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info)
   max_partition_iterations_ = info.max_partition_iterations;
   min_subgraph_size_ = info.min_subgraph_size;
   max_workspace_size_ = info.max_workspace_size;
+  max_shared_mem_size_ = info.max_shared_mem_size;
   dump_subgraphs_ = info.dump_subgraphs;
   weight_stripped_engine_enable_ = info.weight_stripped_engine_enable;
   onnx_model_folder_path_ = info.onnx_model_folder_path;
@@ -2294,6 +2284,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   if (max_workspace_size_ > 0) {
     trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_);
   }
+  if (max_shared_mem_size_ > 0) {
+    trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kTACTIC_SHARED_MEMORY, max_shared_mem_size_);
+  }
   // Only set default compute capabilities if user hasn't explicitly configured them
   constexpr int kDefaultNumComputeCapabilities = 1;  // Default number of compute capabilities for Turing support
   if (trt_config->getNbComputeCapabilities() == 0) {
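
Outside the EP, the two pool limits set above look like this on a plain TensorRT builder config. kWORKSPACE caps per-layer scratch memory; kTACTIC_SHARED_MEMORY caps the shared memory kernel tactics may use during tuning. The values are illustrative, and an existing nvinfer1::IBuilder* builder is assumed.

// Sketch only; error handling elided.
nvinfer1::IBuilderConfig* config = builder->createBuilderConfig();
config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 1ULL << 30);              // 1 GiB scratch
config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kTACTIC_SHARED_MEMORY, 48ULL << 10);  // 48 KiB shared mem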
@@ -2587,7 +2580,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
 #pragma warning(push)
 #pragma warning(disable : 4996)
 #endif
-  size_t mem_size = trt_engine->getDeviceMemorySize();
+  size_t mem_size = trt_engine->getDeviceMemorySizeV2();
 #if defined(_MSC_VER)
 #pragma warning(pop)
 #endif
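
The same deprecated-to-V2 substitution repeats in the hunks below. For reference, a hedged sketch of the surrounding TensorRT 10 pattern: query the activation memory requirement with the V2 API and hand the context a caller-managed buffer via the matching setter. It assumes trt_engine (nvinfer1::ICudaEngine*) and trt_context created with ExecutionContextAllocationStrategy::kUSER_MANAGED.

int64_t mem_size = trt_engine->getDeviceMemorySizeV2();  // replaces deprecated getDeviceMemorySize()
void* activation_mem = nullptr;
cudaMalloc(&activation_mem, static_cast<size_t>(mem_size));
trt_context->setDeviceMemoryV2(activation_mem, mem_size);  // context stops allocating internally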
@@ -2841,7 +2834,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
 #pragma warning(push)
 #pragma warning(disable : 4996)
 #endif
-  size_t mem_size = trt_engine->getDeviceMemorySize();
+  size_t mem_size = trt_engine->getDeviceMemorySizeV2();
 #if defined(_MSC_VER)
 #pragma warning(pop)
 #endif
@@ -2923,7 +2916,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr
   if (cuda_graph_enable_ && !IsGraphCaptured(0)) {
     if (IsGraphCaptureAllowed()) {
       CaptureEnd(0);
-      // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
+      // CUDA work issued to a capturing stream doesn't actually run on the GPU,
       // so run the captured graph here to actually execute the work.
       ORT_RETURN_IF_ERROR(ReplayGraph(0));
     } else {
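
The comment fixed above describes a real CUDA graphs subtlety: kernels launched between begin and end capture are only recorded, never executed. A minimal standalone sketch (using the three-argument CUDA 12 cudaGraphInstantiate signature):

#include <cstdio>

#include <cuda_runtime.h>

__global__ void Touch(int* p) { *p = 42; }

int main() {
  int* d = nullptr;
  cudaMalloc(&d, sizeof(int));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  Touch<<<1, 1, 0, stream>>>(d);  // recorded into the graph, not executed
  cudaStreamEndCapture(stream, &graph);

  cudaGraphExec_t graph_exec;
  cudaGraphInstantiate(&graph_exec, graph, 0);
  cudaGraphLaunch(graph_exec, stream);  // now the kernel actually runs
  cudaStreamSynchronize(stream);

  int h = 0;
  cudaMemcpy(&h, d, sizeof(int), cudaMemcpyDeviceToHost);
  printf("h = %d\n", h);  // prints 42 only because the graph was launched
  return 0;
}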
@@ -2973,7 +2966,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
 #pragma warning(push)
 #pragma warning(disable : 4996)
 #endif
-  size_t mem_size = trt_engine->getDeviceMemorySize();
+  size_t mem_size = trt_engine->getDeviceMemorySizeV2();
 #if defined(_MSC_VER)
 #pragma warning(pop)
 #endif
@@ -3155,7 +3148,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
 #pragma warning(push)
 #pragma warning(disable : 4996)
 #endif
-  size_t mem_size = trt_engine->getDeviceMemorySize();
+  size_t mem_size = trt_engine->getDeviceMemorySizeV2();
 #if defined(_MSC_VER)
 #pragma warning(pop)
 #endif
@@ -3237,7 +3230,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra
   if (cuda_graph_enable_ && !IsGraphCaptured(0)) {
     if (IsGraphCaptureAllowed()) {
       CaptureEnd(0);
-      // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
+      // CUDA work issued to a capturing stream doesn't actually run on the GPU,
       // so run the captured graph here to actually execute the work.
       ORT_RETURN_IF_ERROR(ReplayGraph(0));
     } else {