@@ -61,7 +61,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata)
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy& resource_allocation_strategy)
     : TRTEngine(
           "deserialized_trt",
           serialized_engine,
@@ -71,7 +72,8 @@ TRTEngine::TRTEngine(
           target_platform,
           hardware_compatible,
           requires_output_allocator,
-          serialized_metadata) {}
+          serialized_metadata,
+          resource_allocation_strategy) {}
 
 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -83,7 +85,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           Platform(serialized_info[TARGET_PLATFORM_IDX]),
           static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
           static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
-          serialized_info[SERIALIZED_METADATA_IDX]) {}
+          serialized_info[SERIALIZED_METADATA_IDX],
+          resource_allocation_strategy_from_string(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) {}
 
 TRTEngine::TRTEngine(
     const std::string& mod_name,
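
The deserializing constructor above reads a new RESOURCE_ALLOCATION_STRATEGY_IDX slot and calls resource_allocation_strategy_from_string(), neither of which is defined in this file's hunks. A minimal sketch of what the header-side pieces might look like; only the names come from the diff, while the enumerators and string spellings are assumptions:

```cpp
#include <string>

// Sketch of the strategy enum; the later hunks suggest it is nested inside TRTEngine.
enum class ResourceAllocationStrategy { kStatic, kDynamic };

// Round-trip helpers used by serialize() and the deserializing constructor.
// The "dynamic"/"static" spellings are placeholders, not the PR's actual encoding.
inline std::string resource_allocation_strategy_to_string(ResourceAllocationStrategy strategy) {
  return strategy == ResourceAllocationStrategy::kDynamic ? "dynamic" : "static";
}

inline ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str) {
  return str == "dynamic" ? ResourceAllocationStrategy::kDynamic : ResourceAllocationStrategy::kStatic;
}
```
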
@@ -94,7 +97,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata) {
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy& resource_allocation_strategy) {
   TORCHTRT_CHECK(
       is_supported_on_current_platform(target_platform),
       "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -124,7 +128,12 @@ TRTEngine::TRTEngine(
     cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
   }
 
-  exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
+    this->exec_ctx =
+        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
+  } else {
+    this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  }
   TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
 
   runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
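
Note on the hunk above: kDynamic is mapped to nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE, so TensorRT allocates the context's device memory when an optimization profile is selected instead of reserving the maximum size across all profiles at context creation; the plain createExecutionContext() call on the other path keeps TensorRT's default up-front (static) allocation.
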
@@ -436,7 +445,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
       std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
       std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
       std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
-      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
+      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
+      std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]));
 }
 
 std::vector<std::string> TRTEngine::serialize() {
@@ -459,6 +469,8 @@ std::vector<std::string> TRTEngine::serialize() {
   serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
   serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
   serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
+  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
+      resource_allocation_strategy_to_string(this->resource_allocation_strategy);
 
   return serialized_info;
 }
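
Not shown in this diff: serialized_info is presumably sized by a constant and indexed by an enum in the runtime header, so the accompanying header change would need to add RESOURCE_ALLOCATION_STRATEGY_IDX and grow the expected serialized length by one for the new slot.
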
@@ -467,6 +479,19 @@ void TRTEngine::reset_captured_graph() {
   cudagraph.reset();
 }
 
+void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
+  if (new_strategy != this->resource_allocation_strategy) {
+    this->resource_allocation_strategy = new_strategy;
+    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+      LOG_DEBUG("Setting resource allocation strategy to dynamic");
+      this->exec_ctx = make_trt(
+          cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
+    } else {
+      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+    }
+  }
+}
+
 } // namespace runtime
 } // namespace core
 } // namespace torch_tensorrt
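
For completeness, a hypothetical caller-side sketch of the new setter; `engine` (an already-constructed TRTEngine instance) and the kStatic enumerator are assumptions, since only kDynamic appears in the diff:

```cpp
// Hypothetical usage sketch: `engine` is assumed to be an existing TRTEngine.
// Switching strategies recreates the execution context, per set_resource_allocation_strategy above.
engine->set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy::kDynamic);
// ... run inference with device memory allocated per optimization profile ...
engine->set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy::kStatic);
```
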