Commit e22be64

Added initial implementation
1 parent a93266a commit e22be64

9 files changed: +157 -10 lines changed

core/runtime/TRTEngine.cpp

Lines changed: 31 additions & 6 deletions
@@ -61,7 +61,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata)
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy& resource_allocation_strategy)
     : TRTEngine(
           "deserialized_trt",
           serialized_engine,
@@ -71,7 +72,8 @@ TRTEngine::TRTEngine(
           target_platform,
           hardware_compatible,
           requires_output_allocator,
-          serialized_metadata) {}
+          serialized_metadata,
+          resource_allocation_strategy) {}
 
 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -83,7 +85,8 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           Platform(serialized_info[TARGET_PLATFORM_IDX]),
           static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
           static_cast<bool>(std::stoi(serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX])),
-          serialized_info[SERIALIZED_METADATA_IDX]) {}
+          serialized_info[SERIALIZED_METADATA_IDX],
+          resource_allocation_strategy_from_string(serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX])) {}
 
 TRTEngine::TRTEngine(
     const std::string& mod_name,
@@ -94,7 +97,8 @@ TRTEngine::TRTEngine(
     const Platform& target_platform,
     bool hardware_compatible,
     bool requires_output_allocator,
-    const std::string& serialized_metadata) {
+    const std::string& serialized_metadata,
+    const ResourceAllocationStrategy& resource_allocation_strategy) {
   TORCHTRT_CHECK(
       is_supported_on_current_platform(target_platform),
       "This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
@@ -124,7 +128,12 @@ TRTEngine::TRTEngine(
     cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
   }
 
-  exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  if (this->resource_allocation_strategy == ResourceAllocationStrategy::kDynamic) {
+    this->exec_ctx =
+        make_trt(cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
+  } else {
+    this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+  }
   TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");
 
   runtime_states.old_cudagraphs = CUDAGRAPHS_MODE;
@@ -436,7 +445,8 @@ FlattenedState TRTEngine::__obj_flatten__() {
       std::tuple("hardware_compatible", serialized_info[HW_COMPATIBLE_IDX]),
       std::tuple("serialized_metadata", serialized_info[SERIALIZED_METADATA_IDX]),
      std::tuple("requires_output_allocator", serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX]),
-      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]));
+      std::tuple("target_platform", serialized_info[TARGET_PLATFORM_IDX]),
+      std::tuple("resource_allocation_strategy", serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX]));
 }
 
 std::vector<std::string> TRTEngine::serialize() {
@@ -459,6 +469,8 @@ std::vector<std::string> TRTEngine::serialize() {
   serialized_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = this->requires_output_allocator ? "1" : "0";
   serialized_info[SERIALIZED_METADATA_IDX] = this->serialized_metadata;
   serialized_info[TARGET_PLATFORM_IDX] = this->target_platform.serialize();
+  serialized_info[RESOURCE_ALLOCATION_STRATEGY_IDX] =
+      resource_allocation_strategy_to_string(this->resource_allocation_strategy);
 
   return serialized_info;
 }
@@ -467,6 +479,19 @@ void TRTEngine::reset_captured_graph() {
   cudagraph.reset();
 }
 
+void TRTEngine::set_resource_allocation_strategy(TRTEngine::ResourceAllocationStrategy new_strategy) {
+  if (new_strategy != this->resource_allocation_strategy) {
+    this->resource_allocation_strategy = new_strategy;
+    if (this->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+      std::cout << "Setting resource allocation strategy to dynamic" << std::endl;
+      this->exec_ctx = make_trt(cuda_engine->createExecutionContext());
+    } else {
+      this->exec_ctx = make_trt(
+          cuda_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kON_PROFILE_CHANGE));
+    }
+  }
+}
+
 } // namespace runtime
 } // namespace core
 } // namespace torch_tensorrt
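
For reference, the branch added to the constructor maps the engine-level strategy onto TensorRT's execution-context allocation strategies: the default context preallocates the engine-wide maximum activation memory, while kON_PROFILE_CHANGE sizes it for the active optimization profile. A minimal Python sketch of the same choice, assuming TensorRT 10 bindings that expose ExecutionContextAllocationStrategy:

import tensorrt as trt

def make_context(engine, dynamic):
    # Mirrors the branch added to the TRTEngine constructor above.
    if dynamic:
        # Size activation memory for the active profile rather than the
        # engine-wide maximum (TensorRT's ON_PROFILE_CHANGE strategy).
        return engine.create_execution_context(
            trt.ExecutionContextAllocationStrategy.ON_PROFILE_CHANGE
        )
    # Default strategy: the full device memory is reserved at creation time.
    return engine.create_execution_context()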

core/runtime/TRTEngine.h

Lines changed: 13 additions & 3 deletions
@@ -29,7 +29,8 @@ using FlattenedState = std::tuple<
     std::tuple<std::string, std::string>, // HW compatibility
     std::tuple<std::string, std::string>, // requires_output_allocator
     std::tuple<std::string, std::string>, // serialized metadata
-    std::tuple<std::string, std::string>>; // Platform
+    std::tuple<std::string, std::string>, // Platform
+    std::tuple<std::string, std::string>>; // Resource Allocation Strategy
 
 struct TorchTRTRuntimeStates {
   // Indicates whether CUDAGraphs were enabled in the previous execute_engine
@@ -98,6 +99,8 @@ class DynamicOutputAllocator : public nvinfer1::IOutputAllocator {
 };
 
 struct TRTEngine : torch::CustomClassHolder {
+  // Resource Allocation Strategy
+  enum ResourceAllocationStrategy { kStatic, kDynamic };
   // Each engine needs it's own runtime object
   std::shared_ptr<nvinfer1::IRuntime> rt;
   std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
@@ -128,7 +131,9 @@ struct TRTEngine : torch::CustomClassHolder {
       const Platform& target_platform = get_current_platform(),
       bool hardware_compatible = false,
       bool requires_output_allocator = false,
-      const std::string& serialized_metadata = "");
+      const std::string& serialized_metadata = "",
+      const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
+          TRTEngine::ResourceAllocationStrategy::kStatic);
 
   TRTEngine(std::vector<std::string> serialized_info);
 
@@ -141,7 +146,9 @@ struct TRTEngine : torch::CustomClassHolder {
       const Platform& target_platform = get_current_platform(),
       bool hardware_compatible = false,
       bool requires_output_allocator = false,
-      const std::string& serialized_metadata = "");
+      const std::string& serialized_metadata = "",
+      const TRTEngine::ResourceAllocationStrategy& resource_allocation_strategy =
+          TRTEngine::ResourceAllocationStrategy::kStatic);
 
   TRTEngine& operator=(const TRTEngine& other);
   std::string to_str() const;
@@ -200,6 +207,9 @@ struct TRTEngine : torch::CustomClassHolder {
   std::string cuda_graph_debug_path;
   std::mutex mu;
   std::unique_ptr<TRTEngineProfiler> trt_engine_profiler;
+  ResourceAllocationStrategy resource_allocation_strategy = kStatic;
+  void set_resource_allocation_strategy(ResourceAllocationStrategy new_strategy);
+  ResourceAllocationStrategy get_resource_allocation_strategy();
 };
 
 } // namespace runtime
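
The enum's serialized form is handled by string helpers added in register_jit_hooks.cpp below; a Python mirror of that round trip, as a sketch using names from this commit:

from enum import Enum

class ResourceAllocationStrategy(Enum):
    kStatic = 0
    kDynamic = 1

def to_string(strategy):
    # Mirrors resource_allocation_strategy_to_string.
    return "kDynamic" if strategy is ResourceAllocationStrategy.kDynamic else "kStatic"

def from_string(s):
    # Mirrors resource_allocation_strategy_from_string: anything other than
    # "kDynamic" falls back to kStatic.
    return ResourceAllocationStrategy.kDynamic if s == "kDynamic" else ResourceAllocationStrategy.kStatic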

core/runtime/execute_engine.cpp

Lines changed: 6 additions & 0 deletions
@@ -201,6 +201,12 @@ void create_output_allocator(c10::intrusive_ptr<TRTEngine> compiled_engine) {
 }
 
 std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intrusive_ptr<TRTEngine> compiled_engine) {
+  torch::Tensor dynamic_workspace;
+  if (compiled_engine->resource_allocation_strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+    dynamic_workspace = torch::empty(compiled_engine->cuda_engine->getDeviceMemorySizeV2(), {torch::kCUDA});
+    compiled_engine->exec_ctx->setDeviceMemory(dynamic_workspace.data_ptr());
+  }
+
   auto run_standard_execution = [&]() {
     bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
     bool shape_changed = _validate_shapes(inputs, compiled_engine);
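
In the kDynamic path above, the workspace tensor is allocated per call and released once it goes out of scope after execution. The same idea in Python, as a sketch assuming TensorRT 10 bindings (a device_memory_size_v2 attribute on the engine and a writable device_memory property on the context):

import torch

def attach_per_call_workspace(cuda_engine, exec_ctx):
    # Size the activation workspace from the engine and lend it to the
    # execution context; the tensor must stay alive until execution finishes.
    workspace = torch.empty(
        cuda_engine.device_memory_size_v2, dtype=torch.uint8, device="cuda"
    )
    exec_ctx.device_memory = workspace.data_ptr()
    return workspace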

core/runtime/register_jit_hooks.cpp

Lines changed: 23 additions & 0 deletions
@@ -22,6 +22,21 @@ std::string serialize_bindings(const std::vector<std::string>& bindings) {
   return serialized_binding_info;
 }
 
+std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy) {
+  if (strategy == TRTEngine::ResourceAllocationStrategy::kDynamic) {
+    return std::string("kDynamic");
+  } else {
+    return std::string("kStatic");
+  }
+}
+
+TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str) {
+  if (str == "kDynamic")
+    return TRTEngine::ResourceAllocationStrategy::kDynamic;
+  else
+    return TRTEngine::ResourceAllocationStrategy::kStatic;
+}
+
 static const std::string sym_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; //=
 std::string base64_encode(const std::string& in) {
   std::string out;
@@ -90,6 +105,13 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
         .def("infer_outputs", &TRTEngine::infer_outputs)
         .def("reset_captured_graph", &TRTEngine::reset_captured_graph)
+        .def(
+            "_use_dynamically_allocated_resources",
+            [](const c10::intrusive_ptr<TRTEngine>& self, bool dynamic) -> void {
+              self->set_resource_allocation_strategy(
+                  dynamic ? TRTEngine::ResourceAllocationStrategy::kDynamic
+                          : TRTEngine::ResourceAllocationStrategy::kStatic);
+            })
         .def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
         .def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
         .def_property(
@@ -135,6 +157,7 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("TARGET_PLATFORM_IDX", []() -> int64_t { return TARGET_PLATFORM_IDX; });
   m.def("REQUIRES_OUTPUT_ALLOCATOR_IDX", []() -> int64_t { return REQUIRES_OUTPUT_ALLOCATOR_IDX; });
   m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
+  m.def("RESOURCE_ALLOCATION_STRATEGY_IDX", []() -> int64_t { return RESOURCE_ALLOCATION_STRATEGY_IDX; });
   m.def("_platform_linux_x86_64", []() -> std::string {
     auto it = get_platform_name_map().find(Platform::PlatformEnum::kLINUX_X86_64);
     return it->second;
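
From Python, the binding registered above is reachable through the engine handle held by a TorchTensorRTModule (see the use_dynamically_allocated_resources wrapper in the _TorchTensorRTModule.py hunk below); a usage sketch, assuming trt_module is a module whose engine has been set up:

trt_module.engine._use_dynamically_allocated_resources(True)   # switch to kDynamic
trt_module.engine._use_dynamically_allocated_resources(False)  # back to kStatic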

core/runtime/runtime.h

Lines changed: 4 additions & 0 deletions
@@ -38,13 +38,17 @@ typedef enum {
   SERIALIZED_METADATA_IDX,
   TARGET_PLATFORM_IDX,
   REQUIRES_OUTPUT_ALLOCATOR_IDX,
+  RESOURCE_ALLOCATION_STRATEGY_IDX,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;
 
 std::string base64_encode(const std::string& in);
 std::string base64_decode(const std::string& in);
 std::string serialize_bindings(const std::vector<std::string>& bindings);
 
+std::string resource_allocation_strategy_to_string(TRTEngine::ResourceAllocationStrategy strategy);
+TRTEngine::ResourceAllocationStrategy resource_allocation_strategy_from_string(const std::string& str);
+
 c10::optional<RTDevice> get_most_compatible_device(
     const RTDevice& target_device,
     const RTDevice& curr_device = RTDevice(),
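
Since RESOURCE_ALLOCATION_STRATEGY_IDX is inserted before SERIALIZATION_LEN, the serialized-info vector grows by one slot and the index getters registered in register_jit_hooks.cpp shift to match. A quick consistency check, assuming the torch_tensorrt runtime library is loaded so the ops are registered:

import torch
import torch_tensorrt  # registers the torch.ops.tensorrt index getters

assert torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX() == 10
assert torch.ops.tensorrt.SERIALIZATION_LEN() == 11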
Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+# %%
+import numpy as np
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+from diffusers import DiffusionPipeline
+
+np.random.seed(5)
+torch.manual_seed(5)
+inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+
+settings = {
+    "ir": "dynamo",
+    "use_python_runtime": False,
+    "enabled_precisions": {torch.float32},
+    "immutable_weights": False,
+}
+
+model = models.resnet152(pretrained=True).eval().to("cuda")
+compiled_module = torch_trt.compile(model, inputs=inputs, **settings)
+print((torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3)
+compiled_module(*inputs)
+
+breakpoint()
+with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module):
+    print(
+        "Memory used (GB):",
+        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+    )
+    breakpoint()
+    compiled_module(*inputs)
+    print(
+        "Memory used (GB):",
+        (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 1024**3,
+    )
+breakpoint()
py/torch_tensorrt/dynamo/runtime/_ResourceAllocator.py (new file; path inferred from the import in __init__.py below)

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+from typing import Any
+
+import torch
+
+
+class ResourceAllocatorContext(torch.nn.Module):  # type: ignore[misc]
+    """
+    ResourceAllocatorContext is a context manager module that temporarily enables dynamic resource allocation
+    for all TRT submodules of the given compiled_module. When entering the context,
+    it sets these submodules to use dynamically allocated resources. Upon exiting, it restores them to their
+    original (static) resource allocation mode.
+    """
+
+    def __init__(
+        self,
+        compiled_module: torch.nn.Module,
+    ) -> None:
+        super(ResourceAllocatorContext, self).__init__()
+        self.compiled_module = compiled_module
+
+    def __enter__(self) -> None:
+        print("Entering resource allocator context")
+        for name, submodule in self.compiled_module.named_modules():
+            if "_run_on_acc" in name:
+                submodule.use_dynamically_allocated_resources(dynamic=True)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> None:
+        for name, submodule in self.compiled_module.named_modules():
+            if "_run_on_acc" in name:
+                submodule.use_dynamically_allocated_resources(dynamic=False)
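
A minimal usage sketch of the context manager, mirroring the example script above; compiled_module and inputs are assumed to come from torch_trt.compile:

import torch_tensorrt as torch_trt

with torch_trt.dynamo.runtime.ResourceAllocatorContext(compiled_module):
    # Inside the context, every "_run_on_acc" TRT submodule allocates its
    # activation memory per call; on exit they revert to static allocation.
    outputs = compiled_module(*inputs)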

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 11 additions & 1 deletion
@@ -50,7 +50,10 @@
 REQUIRES_OUTPUT_ALLOCATOR_IDX = (
     torch.ops.tensorrt.REQUIRES_OUTPUT_ALLOCATOR_IDX()
 )  # 9
-SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 10
+RESOURCE_ALLOCATION_STRATEGY_IDX = (
+    torch.ops.tensorrt.RESOURCE_ALLOCATION_STRATEGY_IDX()
+)  # 10
+SERIALIZATION_LEN = torch.ops.tensorrt.SERIALIZATION_LEN()  # 11
 
 
 @for_all_methods(needs_torch_tensorrt_runtime)
@@ -139,6 +142,7 @@ def __init__(
         self.serialized_engine = serialized_engine
         self.engine = None
         self.requires_output_allocator = requires_output_allocator
+        self.resource_allocation_strategy = 0  # Default to static allocation TODO: Make this configurable with the context manager
 
         if (
             serialized_engine
@@ -184,6 +188,9 @@ def _pack_engine_info(self) -> List[str | bytes]:
         engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = str(
             int(self.requires_output_allocator)
         )
+        engine_info[RESOURCE_ALLOCATION_STRATEGY_IDX] = str(
+            int(self.resource_allocation_strategy)
+        )
 
         return engine_info
 
@@ -212,6 +219,9 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
     def _reset_captured_graph(self) -> None:
         self.engine.reset_captured_graph()
 
+    def use_dynamically_allocated_resources(self, dynamic: bool = False) -> None:
+        self.engine._use_dynamically_allocated_resources(dynamic)
+
     def setup_engine(self) -> None:
         """
         Setup engine for a module which has deferred engine setup.

py/torch_tensorrt/dynamo/runtime/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,9 @@
 from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import (  # noqa: F401
     PythonTorchTensorRTModule,
 )
+from torch_tensorrt.dynamo.runtime._ResourceAllocator import (  # noqa: F401
+    ResourceAllocatorContext,
+)
 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (  # noqa: F401
     TorchTensorRTModule,
 )
